Spaces:
Runtime error
Runtime error
import os | |
import time | |
from threading import Thread | |
import gradio as gr | |
import spaces | |
from PIL import Image | |
import torch | |
from transformers import ( | |
AutoProcessor, | |
AutoModelForImageTextToText, | |
Qwen2_5_VLForConditionalGeneration, | |
TextIteratorStreamer, | |
) | |
from flask import Flask, request, jsonify | |
# --------------------------- | |
# Models | |
# --------------------------- | |
# Display name (shown in the UI / accepted by the API) -> model id handed to
# `from_pretrained`.
# NOTE(review): "your-finetuned-htr-model" looks like a placeholder id — confirm
# it resolves, since every entry below is loaded eagerly at import time.
MODEL_PATHS = {
    "Model 1 (Qwen2.5-VL-7B)": "Qwen/Qwen2.5-VL-7B-Instruct",
    "Model 2 (Nanonets-OCR-s)": "nanonets/ocr_model",
    "Model 3 (Finetuned HTR)": "your-finetuned-htr-model",
}

# Loaded model / processor objects, keyed by the same display names.
models = {}
processors = {}

# Load the processor first, then the model, for each configured entry.
for name in MODEL_PATHS:
    path = MODEL_PATHS[name]
    print(f"Loading {name} ...")
    processors[name] = AutoProcessor.from_pretrained(
        path,
        trust_remote_code=True,
    )
    models[name] = AutoModelForImageTextToText.from_pretrained(
        path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
# --------------------------- | |
# Inference Function | |
# --------------------------- | |
def run_inference(model_name, image, prompt=None):
    """Run one of the preloaded models on an image and return the decoded text.

    Args:
        model_name: Key into the module-level ``models`` / ``processors``
            dicts (one of the ``MODEL_PATHS`` display names).
        image: Image passed to the processor's ``images`` argument; the
            callers in this file pass a PIL image.
        prompt: Optional text prompt. ``None`` or an empty string is sent to
            the processor as ``""``.

    Returns:
        The first decoded sequence, with special tokens skipped.

    Raises:
        KeyError: If ``model_name`` is not a loaded model.
    """
    processor = processors[model_name]
    model = models[model_name]
    inputs = processor(text=prompt or "", images=image, return_tensors="pt")
    # The models are loaded with device_map="auto", so place the inputs on the
    # model's own device rather than assuming a hard-coded "cuda".
    inputs = inputs.to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=512)
    return processor.batch_decode(outputs, skip_special_tokens=True)[0]
# --------------------------- | |
# Flask API | |
# --------------------------- | |
app = Flask(__name__)


# NOTE(review): this view was never registered on `app`, so the Flask server
# exposed no endpoint. The "/predict" path is an assumption — confirm it
# against whatever frontend calls this API.
@app.route("/predict", methods=["POST"])
def predict():
    """Handle an OCR request sent as a JSON body.

    Expected JSON fields:
        image: Base64-encoded image bytes (required).
        model_name: One of the loaded model display names; defaults to
            "Model 1 (Qwen2.5-VL-7B)".
        prompt: Optional text prompt; defaults to "".

    Returns:
        ``{"result": <text>}`` on success, or ``({"error": <message>}, 400)``
        when the body is not a JSON object, the image is missing or cannot be
        decoded, or the model name is unknown.
    """
    import base64
    from io import BytesIO

    # silent=True yields None instead of raising on a missing/invalid JSON
    # body, so such requests fall through to the normal 400 responses below.
    data = request.get_json(silent=True) or {}
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    model_name = data.get("model_name", "Model 1 (Qwen2.5-VL-7B)")
    prompt = data.get("prompt", "")
    image_b64 = data.get("image")  # frontend must send base64 image
    if not image_b64:
        return jsonify({"error": "No image provided"}), 400

    # Reject unknown models up front instead of failing with a KeyError (500)
    # inside run_inference.
    if not isinstance(model_name, str) or model_name not in models:
        return jsonify({"error": f"Unknown model: {model_name}"}), 400

    try:
        image_bytes = base64.b64decode(image_b64)
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
    except (TypeError, ValueError, OSError):
        # Covers malformed base64 and bytes that PIL cannot read as an image.
        return jsonify({"error": "Invalid image data"}), 400

    result = run_inference(model_name, image, prompt)
    return jsonify({"result": result})
# --------------------------- | |
# Gradio UI | |
# --------------------------- | |
def gradio_infer(image, model_choice, prompt):
    """Adapt the Gradio input order (image, model, prompt) to `run_inference`."""
    return run_inference(model_name=model_choice, image=image, prompt=prompt)
# Two-column layout: inputs and the run button on the left, OCR text on the right.
with gr.Blocks() as demo:
    gr.Markdown("## ✍ Handwritten OCR with Multiple Models")

    with gr.Row():
        with gr.Column():
            model_choice = gr.Dropdown(list(MODEL_PATHS), label="Select Model")
            prompt = gr.Textbox(label="Custom Prompt (optional)")
            img_input = gr.Image(type="pil", label="Upload Image")
            btn = gr.Button("Run OCR")

        with gr.Column():
            output = gr.Textbox(label="OCR Output")

    # Component order here must match gradio_infer's (image, model, prompt).
    btn.click(
        fn=gradio_infer,
        inputs=[img_input, model_choice, prompt],
        outputs=output,
    )
# --------------------------- | |
# Run both Flask + Gradio | |
# --------------------------- | |
def run_flask():
    """Start the Flask API server on all interfaces, port 7861."""
    bind = {"host": "0.0.0.0", "port": 7861}
    app.run(**bind)
if __name__ == "__main__":
    # Run the Flask API in a daemon thread so it cannot keep the process alive
    # after the Gradio server in the main thread exits (e.g. on Ctrl-C).
    Thread(target=run_flask, daemon=True).start()  # Start Flask API
    demo.launch(server_name="0.0.0.0", server_port=7860)  # Start Gradio UI