File size: 5,469 Bytes
cef4f97
 
22c7b5b
cef4f97
 
 
d1d0907
1a87a19
02ff46f
 
 
ee4b3d0
7dcbad8
 
 
02ff46f
ed456c1
22c7b5b
cef4f97
02ff46f
cef4f97
 
00ee90b
cef4f97
02ff46f
 
 
 
1a87a19
7dcbad8
 
ee4b3d0
cef4f97
ee4b3d0
7dcbad8
 
 
02ff46f
 
7dcbad8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40d5755
cef4f97
ee4b3d0
 
cef4f97
 
 
 
 
 
 
 
 
 
 
 
71a766f
02ff46f
ee4b3d0
7dcbad8
 
 
 
 
ee4b3d0
02ff46f
ee4b3d0
 
02ff46f
7dcbad8
 
 
 
 
 
cef4f97
 
ee4b3d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7dcbad8
ee4b3d0
 
 
 
 
 
 
 
02ff46f
ee4b3d0
 
 
 
02ff46f
cef4f97
 
 
ee4b3d0
cef4f97
 
 
 
 
ee4b3d0
cef4f97
 
f34dca6
7dcbad8
ee4b3d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig
import os
import base64
import spaces
import io
from PIL import Image
import numpy as np
import yaml
from pathlib import Path
from globe import title, description, modelinfor, joinus
import uuid
import tempfile
import time

# Hugging Face Hub id of the GOT-OCR 2.0 checkpoint. It is used for the
# tokenizer, the config and the weights, so all three always come from the
# same repo.
model_name = 'ucaslcl/GOT-OCR2_0'

# trust_remote_code=True lets transformers execute Python code shipped with the
# checkpoint repo; only point model_name at a repo you trust.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# NOTE(review): `config` is not referenced in the visible code; kept in case
# other modules import it.
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map='cuda',
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
)
# Inference only: switch to eval mode and make sure the weights are on the GPU.
model = model.eval().cuda()
# Mirror the EOS-as-pad choice onto the loaded config as well.
model.config.pad_token_id = tokenizer.eos_token_id

def image_to_base64(image):
    """Serialize *image* as PNG and return the bytes as a base64 text string.

    Args:
        image: Any object exposing a PIL-style ``save(fp, format=...)`` method
            (presumably a ``PIL.Image.Image`` — confirm at call sites).

    Returns:
        The base64 encoding of the PNG bytes, decoded to ``str``.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode()

# Directory (relative to the working directory) that receives rendered OCR
# HTML files; created at import time if it does not already exist.
results_folder = Path('.') / 'results'
results_folder.mkdir(exist_ok=True, parents=True)

@spaces.GPU
def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None):
    """Run GOT-OCR on one image for the selected task.

    ``@spaces.GPU`` presumably requests a GPU (Hugging Face ZeroGPU) for the
    duration of the call — confirm against the ``spaces`` package docs.

    Args:
        image: Forwarded unchanged to ``model.chat`` / ``model.chat_crop``.
        task: One of "Plain Text OCR", "Format Text OCR",
            "Fine-grained OCR (Box)", "Fine-grained OCR (Color)",
            "Multi-crop OCR" or "Render Formatted OCR".
        ocr_type: OCR mode forwarded only for the fine-grained tasks.
        ocr_box: Box string forwarded only for "Fine-grained OCR (Box)".
        ocr_color: Color name forwarded only for "Fine-grained OCR (Color)".

    Returns:
        ``(res, html_content, unique_id)``:
        - ``res`` is the model output.
        - ``html_content`` is the text of the rendered HTML file, or ``None``
          when no file was written (always ``None`` for plain-text OCR).
        - ``unique_id`` is the uuid4 string used to name the render file.

    Raises:
        ValueError: if ``task`` is not one of the supported task names.
    """
    unique_id = str(uuid.uuid4())

    if task == "Plain Text OCR":
        res = model.chat(tokenizer, image, ocr_type='ocr')
        return res, None, unique_id

    temp_html_path = results_folder / f"{unique_id}.html"
    # Every remaining task asks the model to render its output to this file.
    render_kwargs = {'render': True, 'save_render_file': str(temp_html_path)}

    if task in ("Format Text OCR", "Render Formatted OCR"):
        res = model.chat(tokenizer, image, ocr_type='format', **render_kwargs)
    elif task == "Fine-grained OCR (Box)":
        res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_box=ocr_box, **render_kwargs)
    elif task == "Fine-grained OCR (Color)":
        res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_color=ocr_color, **render_kwargs)
    elif task == "Multi-crop OCR":
        res = model.chat_crop(tokenizer, image, ocr_type='format', **render_kwargs)
    else:
        # Previously this fell through and crashed with UnboundLocalError on `res`.
        raise ValueError(f"Unsupported task: {task!r}")

    if not temp_html_path.exists():
        return res, None, unique_id

    html_content = temp_html_path.read_text(encoding='utf-8')
    # The HTML is returned in memory, so drop the file instead of letting
    # results/ accumulate one file per request.
    temp_html_path.unlink()
    return res, html_content, unique_id
    
def update_inputs(task):
    """Show or hide the fine-grained OCR controls for the chosen *task*.

    Returns a list of three ``gr.update`` values for, in order, the OCR-type
    dropdown, the box textbox and the color dropdown. A task name not listed
    below yields ``None``.
    """
    if task == "Fine-grained OCR (Box)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=True),
            gr.update(visible=False),
        ]
    if task == "Fine-grained OCR (Color)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=False),
            gr.update(visible=True, choices=["red", "green", "blue"]),
        ]
    no_extra_controls = (
        "Plain Text OCR",
        "Format Text OCR",
        "Multi-crop OCR",
        "Render Formatted OCR",
    )
    if task in no_extra_controls:
        hidden = gr.update(visible=False)
        return [hidden, hidden, hidden]
    return None
def ocr_demo(image, task, ocr_type, ocr_box, ocr_color):
    """Click handler: run OCR and format the results for the two output panes.

    Args:
        image, task, ocr_type, ocr_box, ocr_color: Forwarded unchanged to
            ``process_image``.

    Returns:
        ``(markdown_text, html_or_None)``:
        - ``markdown_text`` is the OCR result wrapped in ``$$ ... $$``
          (presumably so the Markdown pane renders it as display math —
          confirm against the Gradio Markdown LaTeX settings).
        - ``html_or_None`` is an ``<iframe>`` embedding the rendered HTML, or
          ``None`` when no render was produced.
    """
    from html import escape

    # process_image returns a 3-tuple; the request id is not shown in the UI.
    # (Unpacking into two names here used to raise ValueError on every call.)
    res, html_content, _unique_id = process_image(image, task, ocr_type, ocr_box, ocr_color)

    res = f"$$ {res} $$"

    if html_content:
        # srcdoc is a double-quoted attribute: escape the document so quotes and
        # ampersands inside the rendered HTML do not terminate it early. The
        # browser decodes the entities back before loading the iframe.
        html_string = (
            f'<iframe srcdoc="{escape(html_content, quote=True)}" '
            'width="100%" height="600px"></iframe>'
        )
        return res, html_string
    return res, None

def cleanup_old_files(max_age_seconds=3600, folder=None):
    """Delete rendered ``*.html`` files that are older than a cutoff.

    Args:
        max_age_seconds: Files whose modification time is more than this many
            seconds in the past are removed. Defaults to one hour.
        folder: Directory to scan. Defaults to the module-level
            ``results_folder``.
    """
    if folder is None:
        folder = results_folder
    cutoff = time.time() - max_age_seconds
    for file_path in folder.glob('*.html'):
        try:
            if file_path.stat().st_mtime < cutoff:
                file_path.unlink()
        except FileNotFoundError:
            # Removed by someone else between glob() and stat()/unlink().
            continue

# Build the Gradio UI. Components appear in the order they are created inside
# the context managers, so statement order here is the page layout.
with gr.Blocks() as demo:
    # Header text imported from the local `globe` module.
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(joinus)
    
    with gr.Column():
        # type="filepath": handlers receive the uploaded image as a file path string.
        image_input = gr.Image(type="filepath", label="Input Image")
        # Task names must match the strings checked in update_inputs / process_image.
        task_dropdown = gr.Dropdown(
            choices=[
                "Plain Text OCR",
                "Format Text OCR",
                "Fine-grained OCR (Box)",
                "Fine-grained OCR (Color)",
                "Multi-crop OCR",
                "Render Formatted OCR"
            ],
            label="Select Task",
            value="Plain Text OCR"
        )
        # The three fine-grained controls start hidden; update_inputs toggles them.
        ocr_type_dropdown = gr.Dropdown(
            choices=["ocr", "format"],
            label="OCR Type",
            visible=False
        )
        # Free-text box coordinates, forwarded as a string to the model.
        ocr_box_input = gr.Textbox(
            label="OCR Box (x1,y1,x2,y2)",
            placeholder="[100,100,200,200]",
            visible=False
        )
        ocr_color_dropdown = gr.Dropdown(
            choices=["red", "green", "blue"],
            label="OCR Color",
            visible=False
        )
        submit_button = gr.Button("Process")

        # Outputs: OCR text (Markdown) and the rendered HTML (iframe), in that order.
        output_markdown = gr.Markdown(label="🫴🏻📸GOT-OCR")
        output_html = gr.HTML(label="🫴🏻📸GOT-OCR")

    gr.Markdown(modelinfor)

    # Re-evaluate which fine-grained controls are visible whenever the task changes;
    # the output order must match the list order returned by update_inputs.
    task_dropdown.change(
        update_inputs,
        inputs=[task_dropdown],
        outputs=[ocr_type_dropdown, ocr_box_input, ocr_color_dropdown]
    )
    
    # Input order must match ocr_demo's positional parameters.
    submit_button.click(
        ocr_demo,
        inputs=[image_input, task_dropdown, ocr_type_dropdown, ocr_box_input, ocr_color_dropdown],
        outputs=[output_markdown, output_html]
    )

if __name__ == "__main__":
    # Purge render files older than an hour left over from previous runs,
    # then start serving the UI.
    cleanup_old_files()
    demo.launch()