# Gradio OCR demo app (runs as a HuggingFace Space).
import base64
import json
import time
from io import BytesIO
from os import makedirs, path, remove
from tempfile import gettempdir

import gradio as gr
import torch
from PIL import Image

import models
def get_safe_cache_dir(): | |
try: | |
# Thử ghi vào ~/.cache/huggingface (nếu có) | |
default_cache = path.expanduser("~/.cache/huggingface") | |
makedirs(default_cache, exist_ok=True) | |
test_file = path.join(default_cache, "test_write.txt") | |
with open(test_file, "w") as f: | |
f.write("ok") | |
remove(test_file) | |
return default_cache | |
except Exception: | |
# Nếu lỗi (ví dụ trên HuggingFace Spaces), dùng temp | |
return path.join(gettempdir(), "huggingface") | |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
CACHE_DIR = get_safe_cache_dir() | |
AVAILABLE_MODELS = { | |
# "TrOCR (Base Printed)": { | |
# "id": "microsoft/trocr-base-printed", | |
# "type": "trocr" | |
# }, | |
"EraX (VL-2B-V1.5)": { | |
"id": "erax-ai/EraX-VL-2B-V1.5", | |
"type": "erax" | |
} | |
} | |
_model_cache = {} | |
print("Using device:", DEVICE) | |
print("Cache directory:", CACHE_DIR) | |
def load_model(model_key): | |
print("Processing image with model:", model_key) | |
model_id = AVAILABLE_MODELS[model_key]["id"] | |
model_type = AVAILABLE_MODELS[model_key]["type"] | |
print("Model ID:", model_id, "Type:", model_type) | |
if model_id in _model_cache: | |
return _model_cache[model_key] | |
if "trocr" == model_type: | |
model = models.TrOCRModel(model_id, cache_dir=CACHE_DIR, device=DEVICE) | |
elif "erax" == model_type: | |
model = models.EraXModel(model_id, cache_dir=CACHE_DIR, device=DEVICE) | |
else: | |
raise ValueError("Unknown model") | |
_model_cache[model_key] = model | |
print('Load model:', model_id, ' successfully!') | |
return model | |
# Hàm xử lý ảnh đầu vào | |
def gradio_process(image: Image.Image, model_key: str): | |
if image is None: | |
return {"error": "No image provided"} | |
print('Received image size:', image.size) | |
start = time.time() | |
model = load_model(model_key) | |
result = model.predict(image) | |
print('Model predicted successfully!') | |
print('Result:', result) | |
print('Time taken for prediction:', time.time() - start) | |
return json.dumps({ | |
"texts": result, | |
"image_size": { | |
"width": image.width, | |
"height": image.height | |
}, | |
"mode": image.mode, | |
}, indent=4) | |
# Giao diện Gradio | |
demo = gr.Interface( | |
fn=gradio_process, | |
inputs=[ | |
gr.Image(type="pil", label="Upload Image"), | |
gr.Dropdown(choices=list(AVAILABLE_MODELS.keys()), label="Chọn mô hình", value="TrOCR (Base Printed)"), | |
# gr.Textbox(label="Prompt (chỉ dùng cho EraX)", placeholder="Ảnh này có gì?") | |
], | |
outputs=gr.JSON(label="Output (Text/JSON Extract)"), | |
title="Image to Text/JSON Extractor", | |
description="Upload an image and extract structured text using OCR." | |
) | |
if __name__ == "__main__": | |
demo.launch() | |