diabolic6045's picture
Update app.py
c9f8b16 verified
#!/usr/bin/env python3
"""
Gradio app for Sanskrit text transcription using Qwen2.5-VL model
Based on quick_test_improved.py
"""
import gradio as gr
import torch
import base64
import io
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import os
import logging
import spaces
# Configure logging once at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Everything below runs at import time so that the request handlers can
# reuse the already-loaded processor and model as module-level globals.
model_path = 'diabolic6045/Sanskrit-Qwen2.5-VL-7B-Instruct-OCR'

logger.info("Loading processor...")
processor = AutoProcessor.from_pretrained(model_path)

logger.info("Loading Sanskrit OCR model...")

# Let the loader decide placement ("auto") when CUDA is visible; otherwise
# pin the model to the CPU.
if torch.cuda.is_available():
    device_map = "auto"
else:
    device_map = "cpu"

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    # bfloat16 when CUDA is available, float32 otherwise.
    torch_dtype=torch.float32 if not torch.cuda.is_available() else torch.bfloat16,
    device_map=device_map,
)
model.eval()

# Record (and log) the device of the model's first parameter.
device = next(model.parameters()).device
logger.info(f"Model loaded on device: {device}")
def check_model_status() -> str:
    """Report whether the module-level model and processor are available.

    Returns:
        A short status string for the UI: a "ready" message when both the
        ``model`` and ``processor`` globals are not ``None``, a "not loaded"
        message when either of them is ``None``, or an error message when
        looking them up raises (for example if a global is undefined).
    """
    try:
        if model is not None and processor is not None:
            # The check mark was previously stored as mis-decoded bytes and
            # showed up as garbage in the status box.
            return "✅ Model loaded and ready"
        return "⏳ Model not loaded yet"
    except Exception as e:
        return f"❌ Model error: {str(e)}"
@spaces.GPU
def transcribe_sanskrit(image, custom_prompt, progress=gr.Progress()):
    """Transcribe the Sanskrit text in an image using the pre-loaded model.

    Args:
        image: The image to transcribe, as supplied by the Gradio image
            input; ``None`` when nothing has been uploaded.
        custom_prompt: Instruction text sent to the model together with the
            image. ``None``, an empty string, or whitespace-only text
            selects the default transcription prompt.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        The decoded transcription (an empty string if nothing was decoded),
        a hint asking for an image when ``image`` is ``None``, or a
        user-facing error message if any step raises.
    """
    if image is None:
        return "Please upload an image first."

    try:
        progress(0.1, desc="Processing image...")

        # A missing prompt (None) is treated like a blank one instead of
        # failing on ``None.strip()``; a non-blank prompt is passed through
        # unchanged.
        if custom_prompt and custom_prompt.strip():
            prompt = custom_prompt
        else:
            prompt = "Please transcribe the Sanskrit text shown in this image:"

        # Single-turn conversation: one user message carrying the image and
        # the instruction text.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt}
                ]
            }
        ]

        # Render the chat template to text (with the generation prompt
        # appended) and collect the vision inputs from the same messages.
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move every input tensor to the device the model's parameters are on.
        model_device = next(model.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        progress(0.5, desc="Generating transcription...")

        # Greedy decoding (no sampling) with gradients disabled.
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                repetition_penalty=1.1
            )

        # generate() returns prompt + continuation; drop the prompt tokens
        # so only the newly generated ids are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        progress(1.0, desc="Complete!")
        return output_text[0] if output_text else ""
    except Exception as e:
        # UI boundary: never let the exception escape to Gradio. Log it with
        # its traceback and hand the user a readable message instead.
        logger.exception("Error in transcribe_sanskrit: %s", e)
        return f"❌ Error occurred: {str(e)}\n\nPlease try again or check if the model files are properly loaded."
def create_gradio_interface():
    """Build the Gradio Blocks UI for the Sanskrit transcription app.

    The layout is a header followed by one row with two columns: the left
    column holds the image input, the prompt textbox, the transcribe button
    and usage instructions; the right column holds the transcription output,
    the model status box with its refresh button, and model information.

    Transcription runs both when the button is clicked and whenever the
    image input changes. The model status is refreshed on page load and on
    demand.

    Returns:
        The configured ``gr.Blocks`` app; launching it is left to the caller.
    """
    # NOTE: the contents of the triple-quoted strings below are kept flush
    # left on purpose so the text handed to Gradio carries no extra
    # indentation.
    with gr.Blocks(
        title="Sanskrit Text Transcription",
        theme=gr.themes.Soft()
    ) as app:
        # Emoji here (and on the buttons below) were previously stored as
        # mis-decoded bytes and rendered as garbage characters.
        gr.HTML("""
<div class="main-header">
<h1>🕉️ Sanskrit Text Transcription</h1>
<p>Upload an image containing Sanskrit text and get an accurate transcription using the specialized Sanskrit OCR model</p>
<p><strong>🚀 Powered by ZeroGPU:</strong> Dynamic GPU allocation for efficient processing</p>
</div>
""")

        with gr.Row():
            # Left column: inputs and instructions.
            with gr.Column(scale=1):
                gr.Markdown("### Upload Image")
                image_input = gr.Image(
                    type="pil",
                    label="Sanskrit Text Image",
                    height=400
                )

                gr.Markdown("### Custom Prompt (Optional)")
                custom_prompt = gr.Textbox(
                    label="Custom transcription prompt",
                    placeholder="Please transcribe the Sanskrit text shown in this image:",
                    lines=2,
                    value="Please transcribe the Sanskrit text shown in this image:"
                )

                transcribe_btn = gr.Button(
                    "🕉️ Transcribe Sanskrit Text",
                    variant="primary",
                    size="lg"
                )

                gr.Markdown("""
### Instructions:
1. Upload an image containing Sanskrit text
2. Optionally modify the prompt for better results
3. Click the transcribe button
4. View the transcribed text below
""")

            # Right column: results, model status and model information.
            with gr.Column(scale=1):
                gr.Markdown("### Transcription Result")
                output_text = gr.Textbox(
                    label="Transcribed Sanskrit Text",
                    lines=10,
                    max_lines=20,
                    show_copy_button=True
                )

                gr.Markdown("### Model Information")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="Checking...",
                    interactive=False
                )
                check_status_btn = gr.Button("🔄 Check Model Status", size="sm")

                gr.Markdown("""
**Model:** diabolic6045/Sanskrit-Qwen2.5-VL-7B-Instruct-OCR
**Features:**
- Multimodal vision-language model
- Pre-trained specifically for Sanskrit OCR
- Supports various Sanskrit scripts
- High accuracy Sanskrit text transcription
""")

        # Event handlers: explicit transcription via the button.
        transcribe_btn.click(
            fn=transcribe_sanskrit,
            inputs=[image_input, custom_prompt],
            outputs=output_text,
            show_progress=True
        )

        # Auto-transcribe whenever the image input changes.
        image_input.change(
            fn=transcribe_sanskrit,
            inputs=[image_input, custom_prompt],
            outputs=output_text,
            show_progress=True
        )

        # Manual model status refresh.
        check_status_btn.click(
            fn=check_model_status,
            outputs=model_status
        )

        # Populate the model status box when the page loads.
        app.load(
            fn=check_model_status,
            outputs=model_status
        )

    return app
def main():
    """Build the Gradio interface and start serving it."""
    logger.info("Starting Sanskrit Transcription Gradio App...")

    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "share": False,  # do not request a public share link
        "max_threads": 4,  # thread cap handed to Gradio
    }

    demo = create_gradio_interface()
    demo.launch(**launch_options)


if __name__ == "__main__":
    main()