# ocr-detection / app.py
# Author: vungocthach1112 — commit e5254aa ("update model")
import gradio as gr
from PIL import Image
import json
from io import BytesIO
import base64
import torch
from tempfile import gettempdir
from os import path, makedirs, remove
import models
import time
def get_safe_cache_dir():
    """Return a writable directory to use as the model download cache.

    First try ``~/.cache/huggingface``: create it if needed and check that it
    is writable by creating and then deleting a small probe file. If that
    fails (e.g. a read-only home directory, as can happen on Hugging Face
    Spaces), fall back to ``<system temp dir>/huggingface``.

    Returns:
        str: Path of a directory that exists on disk.
    """
    try:
        default_cache = path.expanduser("~/.cache/huggingface")
        makedirs(default_cache, exist_ok=True)
        probe = path.join(default_cache, "test_write.txt")
        with open(probe, "w", encoding="utf-8") as f:
            f.write("ok")
        remove(probe)
        return default_cache
    except (OSError, ValueError):
        # OSError covers missing/read-only directories and permission errors;
        # ValueError (incl. UnicodeError) covers un-encodable home paths.
        # Either way, use the system temp directory instead, and create it so
        # callers always receive a directory that actually exists.
        fallback = path.join(gettempdir(), "huggingface")
        makedirs(fallback, exist_ok=True)
        return fallback
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Directory handed to the model wrappers as their ``cache_dir``.
CACHE_DIR = get_safe_cache_dir()

# Models offered in the UI, keyed by their display name. Each entry holds the
# model repository "id" and the wrapper "type" that load_model dispatches on.
AVAILABLE_MODELS = {
    # Disabled entry, kept for reference:
    # "TrOCR (Base Printed)": {"id": "microsoft/trocr-base-printed", "type": "trocr"},
    "EraX (VL-2B-V1.5)": dict(id="erax-ai/EraX-VL-2B-V1.5", type="erax"),
}

# Model instances created by load_model, keyed by display name.
_model_cache = dict()

print(f"Using device: {DEVICE}")
print(f"Cache directory: {CACHE_DIR}")
def load_model(model_key):
    """Return the model wrapper for ``model_key``, constructing it on first use.

    Instances are memoised in the module-level ``_model_cache`` dict under
    ``model_key``, so each configured model is built at most once per process.

    Args:
        model_key: Display name of the model; must be a key of
            ``AVAILABLE_MODELS``.

    Returns:
        The ``models.TrOCRModel`` or ``models.EraXModel`` instance for the key.

    Raises:
        KeyError: If ``model_key`` is not a key of ``AVAILABLE_MODELS``.
        ValueError: If the configured model ``type`` is not supported.
    """
    print("Processing image with model:", model_key)
    model_id = AVAILABLE_MODELS[model_key]["id"]
    model_type = AVAILABLE_MODELS[model_key]["type"]
    print("Model ID:", model_id, "Type:", model_type)
    # Look the cache up with the same key the model is stored under below.
    # Checking model_id here never matched, so every request reloaded the
    # model from scratch.
    if model_key in _model_cache:
        return _model_cache[model_key]
    if model_type == "trocr":
        model = models.TrOCRModel(model_id, cache_dir=CACHE_DIR, device=DEVICE)
    elif model_type == "erax":
        model = models.EraXModel(model_id, cache_dir=CACHE_DIR, device=DEVICE)
    else:
        raise ValueError("Unknown model")
    _model_cache[model_key] = model
    print('Load model:', model_id, ' successfully!')
    return model
# Input-image handler wired into the Gradio interface.
def gradio_process(image: Image.Image, model_key: str):
    """Run OCR on ``image`` with the model selected by ``model_key``.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.
        model_key: Display name of the model (a key of ``AVAILABLE_MODELS``).

    Returns:
        On success, a JSON string (indent=4) with the model output under
        ``"texts"``, the image ``"image_size"`` (width/height) and its
        ``"mode"``. On invalid input, a dict with a single ``"error"`` message.
    """
    if image is None:
        return {"error": "No image provided"}
    # Reject a missing or stale dropdown selection up front instead of letting
    # the AVAILABLE_MODELS lookup inside load_model raise KeyError.
    if model_key not in AVAILABLE_MODELS:
        return {"error": f"Unknown model: {model_key}"}
    print('Received image size:', image.size)
    # perf_counter is the monotonic clock intended for measuring durations.
    start = time.perf_counter()
    model = load_model(model_key)
    loaded = time.perf_counter()
    result = model.predict(image)
    finished = time.perf_counter()
    print('Model predicted successfully!')
    print('Result:', result)
    # Report loading and inference separately so the "prediction" figure does
    # not include first-use model loading.
    print('Time taken for loading model:', loaded - start)
    print('Time taken for prediction:', finished - loaded)
    return json.dumps({
        "texts": result,
        "image_size": {
            "width": image.width,
            "height": image.height
        },
        "mode": image.mode,
    }, indent=4, ensure_ascii=False)  # keep non-ASCII text readable, not \uXXXX-escaped
# Gradio user interface.
demo = gr.Interface(
    fn=gradio_process,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        # Preselect the first configured model so the default is always one of
        # the dropdown choices. The previous hard-coded "TrOCR (Base Printed)"
        # is no longer in AVAILABLE_MODELS. The label means "Choose model".
        gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            label="Chọn mô hình",
            value=next(iter(AVAILABLE_MODELS), None),
        ),
        # Disabled prompt input (EraX only):
        # gr.Textbox(label="Prompt (EraX only)", placeholder="What is in this image?")
    ],
    outputs=gr.JSON(label="Output (Text/JSON Extract)"),
    title="Image to Text/JSON Extractor",
    description="Upload an image and extract structured text using OCR."
)
def _launch_app():
    """Start serving the ``demo`` Gradio interface."""
    demo.launch()


# Only launch when executed directly, not when imported.
if __name__ == "__main__":
    _launch_app()