import base64
from io import BytesIO
from threading import Thread

import gradio as gr
import spaces
import torch
from flask import Flask, jsonify, request
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    TextIteratorStreamer,
)
# ---------------------------
# Models
# ---------------------------
MODEL_PATHS = {
    "Model 1 (Qwen2.5-VL-7B)": "Qwen/Qwen2.5-VL-7B-Instruct",
    "Model 2 (Nanonets-OCR-s)": "nanonets/Nanonets-OCR-s",
    "Model 3 (Finetuned HTR)": "your-finetuned-htr-model",  # placeholder: point at your own fine-tuned HTR checkpoint
}
models = {}
processors = {}
for name, path in MODEL_PATHS.items():
    print(f"Loading {name} ...")
    processors[name] = AutoProcessor.from_pretrained(path, trust_remote_code=True)
    # fp16 weights, placed across available devices (device_map="auto" requires `accelerate`)
    models[name] = AutoModelForImageTextToText.from_pretrained(
        path, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
    )
# ---------------------------
# Inference Function
# ---------------------------
@spaces.GPU  # request ZeroGPU on Hugging Face Spaces; a no-op on other hardware
def run_inference(model_name, image, prompt=None):
    processor, model = processors[model_name], models[model_name]
    # Qwen-style processors expect the image placeholder injected via the chat template
    messages = [{"role": "user", "content": [
        {"type": "image"}, {"type": "text", "text": prompt or "Transcribe the text in this image."}]}]
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=512)
    # decode only the newly generated tokens, not the echoed prompt
    return processor.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0].strip()
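# Optional streaming variant: a minimal sketch using the TextIteratorStreamer imported
# above. It assumes the processor exposes a `tokenizer` attribute (true for the Qwen-style
# processors here) and is not wired into the UI; a `yield`-based Gradio handler could use it.
def run_inference_streaming(model_name, image, prompt=None):
    processor, model = processors[model_name], models[model_name]
    messages = [{"role": "user", "content": [
        {"type": "image"}, {"type": "text", "text": prompt or "Transcribe the text in this image."}]}]
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so run it on a thread and yield partial text as tokens arrive
    Thread(target=model.generate, kwargs=dict(**inputs, max_new_tokens=512, streamer=streamer)).start()
    buffer = ""
    for chunk in streamer:
        buffer += chunk
        yield buffer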
# ---------------------------
# Flask API
# ---------------------------
app = Flask(__name__)

@app.route("/predict", methods=["POST"])
def predict():
    data = request.get_json(force=True)
    model_name = data.get("model_name", "Model 1 (Qwen2.5-VL-7B)")
    if model_name not in models:
        return jsonify({"error": f"Unknown model: {model_name}"}), 400
    prompt = data.get("prompt", "")
    image_b64 = data.get("image")  # the client must send a base64-encoded image
    if not image_b64:
        return jsonify({"error": "No image provided"}), 400
    image = Image.open(BytesIO(base64.b64decode(image_b64))).convert("RGB")
    result = run_inference(model_name, image, prompt)
    return jsonify({"result": result})
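# Example client call (illustrative; assumes the API is reachable at localhost:7861 and
# that a local `sample.png` exists, so adjust both to your deployment):
#
#   import base64, requests
#   with open("sample.png", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode()
#   r = requests.post("http://localhost:7861/predict",
#                     json={"model_name": "Model 1 (Qwen2.5-VL-7B)",
#                           "prompt": "Transcribe this page.",
#                           "image": b64})
#   print(r.json()["result"])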
# ---------------------------
# Gradio UI
# ---------------------------
def gradio_infer(image, model_choice, prompt):
    return run_inference(model_choice, image, prompt)

with gr.Blocks() as demo:
    gr.Markdown("## ✍ Handwritten OCR with Multiple Models")
    with gr.Row():
        with gr.Column():
            model_choice = gr.Dropdown(
                list(MODEL_PATHS.keys()),
                value=list(MODEL_PATHS.keys())[0],  # preselect so the handler never receives None
                label="Select Model",
            )
            prompt = gr.Textbox(label="Custom Prompt (optional)")
            img_input = gr.Image(type="pil", label="Upload Image")
            btn = gr.Button("Run OCR")
        with gr.Column():
            output = gr.Textbox(label="OCR Output")
    btn.click(gradio_infer, [img_input, model_choice, prompt], output)
# ---------------------------
# Run both Flask + Gradio
# ---------------------------
def run_flask():
    # serve the REST API on its own port, separate from the Gradio UI
    app.run(host="0.0.0.0", port=7861)

if __name__ == "__main__":
    Thread(target=run_flask, daemon=True).start()  # daemon: the API thread exits with the main process
    demo.launch(server_name="0.0.0.0", server_port=7860)  # Gradio UI (blocking call)