from functools import partial
import os
import re
import tempfile

import gradio as gr
import spaces
import torch
from PIL import Image, ImageDraw
from transformers import AutoModel, AutoTokenizer

print("Loading model and tokenizer...")
model_name = "deepseek-ai/DeepSeek-OCR"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
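
# The model is loaded on CPU at import time; the @spaces.GPU handler below
# moves it to the GPU per request (the usual pattern for ZeroGPU Spaces).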
model = AutoModel.from_pretrained(
    model_name,
    _attn_implementation="flash_attention_2",
    trust_remote_code=True,
    use_safetensors=True,
)
model = model.eval()
print("✅ Model loaded successfully.")


def find_result_image(path):
    """Return the first annotated result image the model saved, if any."""
    for filename in os.listdir(path):
        if "grounding" in filename or "result" in filename:
            try:
                image_path = os.path.join(path, filename)
                return Image.open(image_path)
            except Exception as e:
                print(f"Error opening result image {filename}: {e}")
    return None

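# @spaces.GPU requests ZeroGPU hardware for the duration of each call.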
@spaces.GPU
def process_ocr_task(image, model_size, ref_text, task_type):
    """
    Process an image with DeepSeek-OCR for any of the supported tasks.

    Draws ALL detected bounding boxes onto the original image for ANY task.
    """
    if image is None:
        return "Please upload an image first.", None

    print("🚀 Moving model to GPU...")
    model_gpu = model.cuda().to(torch.bfloat16)
    print("✅ Model is on GPU.")

    with tempfile.TemporaryDirectory() as output_path:
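        # Build the task prompt. <|grounding|> requests layout-grounded output,
        # and <|ref|>...<|/ref|> wraps the text the model should localize.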
        if task_type == "📝 Free OCR":
            prompt = "<image>\nFree OCR."
        elif task_type == "📄 Convert to Markdown":
            prompt = "<image>\n<|grounding|>Convert the document to markdown."
        elif task_type == "📈 Parse Figure":
            prompt = "<image>\nParse the figure."
        elif task_type == "🔍 Locate Object by Reference":
            if not ref_text or ref_text.strip() == "":
                raise gr.Error("For the 'Locate' task, you must provide the reference text to find!")
            prompt = f"<image>\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
        else:
            prompt = "<image>\nFree OCR."

        # Save the uploaded PIL image so the model can read it from disk.
        temp_image_path = os.path.join(output_path, "temp_image.png")
        image.save(temp_image_path)
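
        # Resolution presets; "Gundam" mixes a 1024 base with 640 tiles plus
        # dynamic cropping, which suits large, dense documents.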
        size_configs = {
            "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
            "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
            "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
            "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
            "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
        }
        config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])

        print(f"🏃 Running inference with prompt: {prompt}")
        text_result = model_gpu.infer(
            tokenizer,
            prompt=prompt,
            image_file=temp_image_path,
            output_path=output_path,
            base_size=config["base_size"],
            image_size=config["image_size"],
            crop_mode=config["crop_mode"],
            save_results=True,
            test_compress=True,
            eval_mode=True,
        )

        print(f"====\n📄 Text Result: {text_result}\n====")

        result_image_pil = None

        # DeepSeek-OCR marks detections as <|det|>[[x1, y1, x2, y2]]<|/det|>,
        # with coordinates normalized to a 0-1000 grid.
        pattern = re.compile(r"<\|det\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|/det\|>")
        matches = list(pattern.finditer(text_result))

        if matches:
            print(f"✅ Found {len(matches)} bounding box(es). Drawing on the original image.")

            image_with_bboxes = image.copy()
            draw = ImageDraw.Draw(image_with_bboxes)
            w, h = image.size

            for match in matches:
                coords_norm = [int(c) for c in match.groups()]
                x1_norm, y1_norm, x2_norm, y2_norm = coords_norm

                # Rescale from the 0-1000 grid to pixel coordinates.
                x1 = int(x1_norm / 1000 * w)
                y1 = int(y1_norm / 1000 * h)
                x2 = int(x2_norm / 1000 * w)
                y2 = int(y2_norm / 1000 * h)

                # Draw each box instead of cropping, so all matches stay
                # visible (cropping here would keep only the last region).
                draw.rectangle([x1, y1, x2, y2], outline="red", width=3)

            result_image_pil = image_with_bboxes
        else:
            print("⚠️ No bounding box coordinates found in text result. Falling back to searching for a result image file.")
            result_image_pil = find_result_image(output_path)

        return text_result, result_image_pil


with gr.Blocks(title="Text Extraction Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🐳 Full Demo of DeepSeek-OCR 🐳

        Use the tabs below to switch between Free OCR and Locate modes.
        """
    )

    with gr.Tabs():
        with gr.TabItem("Free OCR"):
            with gr.Row():
                with gr.Column(scale=1):
                    free_image = gr.Image(type="pil", label="🖼️ Upload Image", sources=["upload", "clipboard"])
                    free_model_size = gr.Dropdown(choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"], value="Base", label="⚙️ Resolution Size")
                    free_btn = gr.Button("Run Free OCR", variant="primary")

                with gr.Column(scale=2):
                    free_output_text = gr.Textbox(label="📄 Text Result", lines=15, show_copy_button=True)
                    free_output_image = gr.Image(label="🖼️ Image Result (if any)", type="pil")

            # partial() pins the task and empty ref_text, so the click handler
            # only needs the two visible inputs.
            free_ocr = partial(process_ocr_task, task_type="📝 Free OCR", ref_text="")
            free_btn.click(fn=free_ocr, inputs=[free_image, free_model_size], outputs=[free_output_text, free_output_image])

        with gr.TabItem("Locate"):
            with gr.Row():
                with gr.Column(scale=1):
                    loc_image = gr.Image(type="pil", label="🖼️ Upload Image", sources=["upload", "clipboard"])
                    loc_model_size = gr.Dropdown(choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"], value="Base", label="⚙️ Resolution Size")
                    loc_ref_text = gr.Textbox(label="Reference Text", value="pets", placeholder="What should the model locate, e.g. 'pets'?")
                    loc_btn = gr.Button("Locate", variant="primary")

                with gr.Column(scale=2):
                    loc_output_text = gr.Textbox(label="📄 Text Result", lines=15, show_copy_button=True)
                    loc_output_image = gr.Image(label="🖼️ Image Result (if any)", type="pil")

            # Bind only the task; the reference text now comes from the textbox
            # instead of being hard-coded.
            locate_task = partial(process_ocr_task, task_type="🔍 Locate Object by Reference")
            loc_btn.click(fn=locate_task, inputs=[loc_image, loc_model_size, loc_ref_text], outputs=[loc_output_text, loc_output_image])

    # Each example carries four values (image, size, ref_text, task_type), so
    # the inputs list needs four components; two hidden textboxes receive the
    # extra values and forward them to process_ocr_task.
    example_ref_text = gr.Textbox(visible=False)
    example_task_type = gr.Textbox(visible=False)
    gr.Examples(
        examples=[
            ["doc_markdown.png", "Gundam (Recommended)", "", "📄 Convert to Markdown"],
            ["chart.png", "Gundam (Recommended)", "", "📈 Parse Figure"],
            ["teacher.jpg", "Base", "the teacher", "🔍 Locate Object by Reference"],
            ["math_locate.jpg", "Small", "20-10", "🔍 Locate Object by Reference"],
            ["receipt.jpg", "Base", "", "📝 Free OCR"],
        ],
        inputs=[free_image, free_model_size, example_ref_text, example_task_type],
        outputs=[free_output_text, free_output_image],
        fn=process_ocr_task,
        run_on_click=True,
        cache_examples=False,
    )
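
    # Note: the example files above are assumed to sit next to this script;
    # adjust the paths if they actually live in the examples/ directory
    # created below.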


if __name__ == "__main__":
    if not os.path.exists("examples"):
        os.makedirs("examples")

    demo.queue(max_size=20).launch(share=True)