Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,469 Bytes
cef4f97 22c7b5b cef4f97 d1d0907 1a87a19 02ff46f ee4b3d0 7dcbad8 02ff46f ed456c1 22c7b5b cef4f97 02ff46f cef4f97 00ee90b cef4f97 02ff46f 1a87a19 7dcbad8 ee4b3d0 cef4f97 ee4b3d0 7dcbad8 02ff46f 7dcbad8 40d5755 cef4f97 ee4b3d0 cef4f97 71a766f 02ff46f ee4b3d0 7dcbad8 ee4b3d0 02ff46f ee4b3d0 02ff46f 7dcbad8 cef4f97 ee4b3d0 7dcbad8 ee4b3d0 02ff46f ee4b3d0 02ff46f cef4f97 ee4b3d0 cef4f97 ee4b3d0 cef4f97 f34dca6 7dcbad8 ee4b3d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig
import os
import base64
import spaces
import io
from PIL import Image
import numpy as np
import yaml
from pathlib import Path
from globe import title, description, modelinfor, joinus
import uuid
import tempfile
import time
# Hugging Face repo id of the GOT-OCR 2.0 checkpoint; used for every
# from_pretrained call below so the tokenizer, config and weights always match.
model_name = 'ucaslcl/GOT-OCR2_0'
# trust_remote_code=True: the checkpoint ships its own modelling code
# (it provides the `chat` / `chat_crop` methods used by process_image).
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# NOTE(review): `config` is not referenced anywhere in the visible code; kept
# in case other code imports it.
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map='cuda',
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
)
# Inference only. The explicit .cuda() after device_map='cuda' is presumably
# needed by the ZeroGPU (`spaces`) runtime — TODO confirm before removing.
model = model.eval().cuda()
# Generation pads with EOS since the tokenizer's EOS id is used as pad id.
model.config.pad_token_id = tokenizer.eos_token_id
def image_to_base64(image):
    """Serialize *image* as PNG and return it as a base64 text string.

    Args:
        image: Any object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64 encoding of the PNG bytes, decoded to ``str``.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()
# Directory that receives the HTML files written by the model's render mode
# (see process_image); created up front so those writes have a target.
results_folder = Path('./results')
# parents/exist_ok make this a no-op when the directory already exists.
results_folder.mkdir(exist_ok=True, parents=True)
@spaces.GPU
def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None):
    """Run GOT-OCR on *image* for the selected *task*.

    Args:
        image: Image passed straight to ``model.chat`` / ``model.chat_crop``;
            the UI supplies a file path (``gr.Image(type="filepath")``).
        task: One of the task labels offered by the task dropdown.
        ocr_type: ``"ocr"`` or ``"format"``; only used by the fine-grained tasks.
        ocr_box: Box string for "Fine-grained OCR (Box)", forwarded unparsed.
        ocr_color: Colour name for "Fine-grained OCR (Color)".

    Returns:
        ``(text, html_or_None, unique_id)``: the OCR result, the rendered HTML
        if the model wrote a render file, and the id used to name that file.

    Raises:
        ValueError: If *task* is not a recognised task label.
    """
    unique_id = str(uuid.uuid4())
    temp_html_path = results_folder / f"{unique_id}.html"

    if task == "Plain Text OCR":
        res = model.chat(tokenizer, image, ocr_type='ocr')
        return res, None, unique_id

    # Every remaining task renders its result to a per-request HTML file.
    render_kwargs = {'render': True, 'save_render_file': str(temp_html_path)}
    if task in ("Format Text OCR", "Render Formatted OCR"):
        res = model.chat(tokenizer, image, ocr_type='format', **render_kwargs)
    elif task == "Fine-grained OCR (Box)":
        res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_box=ocr_box, **render_kwargs)
    elif task == "Fine-grained OCR (Color)":
        res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_color=ocr_color, **render_kwargs)
    elif task == "Multi-crop OCR":
        res = model.chat_crop(tokenizer, image, ocr_type='format', **render_kwargs)
    else:
        # Without this an unknown task reached the return with `res` unbound.
        raise ValueError(f"Unsupported task: {task!r}")

    if not temp_html_path.exists():
        return res, None, unique_id
    # Explicit UTF-8 instead of the platform default encoding
    # (assumes the model's renderer writes UTF-8 — TODO confirm).
    html_content = temp_html_path.read_text(encoding='utf-8')
    return res, html_content, unique_id
def update_inputs(task):
    """Return visibility updates for the three task-specific input widgets.

    The updates target, in order: the OCR-type dropdown, the OCR-box textbox
    and the OCR-colour dropdown.

    Args:
        task: The label currently selected in the task dropdown.

    Returns:
        A list of exactly three ``gr.update`` objects. Tasks that need no
        extra input, and any unrecognised task, hide all three widgets.
    """
    if task == "Fine-grained OCR (Box)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=True),
            gr.update(visible=False),
        ]
    if task == "Fine-grained OCR (Color)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=False),
            gr.update(visible=True, choices=["red", "green", "blue"]),
        ]
    # "Plain Text OCR", "Format Text OCR", "Multi-crop OCR" and
    # "Render Formatted OCR" need no extra input. Falling back here for any
    # other value guarantees three updates instead of an implicit None, and
    # builds a separate update object per widget.
    return [gr.update(visible=False) for _ in range(3)]
def ocr_demo(image, task, ocr_type, ocr_box, ocr_color):
    """Gradio click handler: run OCR and format the results for display.

    Args:
        image, task, ocr_type, ocr_box, ocr_color: Forwarded unchanged to
            ``process_image``.

    Returns:
        ``(markdown, html)``: the OCR text wrapped in ``$$ ... $$`` for the
        Markdown output, and an ``<iframe>`` embedding the rendered HTML, or
        ``None`` when no HTML was produced.
    """
    import html

    # process_image returns (text, html, unique_id); the id is not displayed.
    res, html_content, _unique_id = process_image(image, task, ocr_type, ocr_box, ocr_color)
    res = f"$$ {res} $$"
    if html_content:
        # srcdoc is a double-quoted attribute: escape the document so quotes,
        # ampersands and angle brackets inside it cannot end the attribute
        # early. The browser unescapes it before rendering.
        escaped = html.escape(html_content, quote=True)
        html_string = f'<iframe srcdoc="{escaped}" width="100%" height="600px"></iframe>'
        return res, html_string
    return res, None
def cleanup_old_files(max_age_seconds=3600, folder=None):
    """Delete rendered ``*.html`` results older than *max_age_seconds*.

    Args:
        max_age_seconds: Files last modified more than this many seconds ago
            are removed. Defaults to one hour.
        folder: Directory to sweep; defaults to the module's ``results_folder``.

    Returns:
        The number of files removed.
    """
    target = results_folder if folder is None else Path(folder)
    cutoff = time.time() - max_age_seconds
    removed = 0
    for file_path in target.glob('*.html'):
        try:
            if file_path.stat().st_mtime < cutoff:
                file_path.unlink()
                removed += 1
        except FileNotFoundError:
            # The file disappeared between glob and stat/unlink (e.g. a
            # concurrent cleanup); nothing left to delete.
            continue
    return removed
# Gradio UI. Components are instantiated inside the `with` contexts in the
# order they appear on the page, so statement order here is the layout.
# NOTE(review): nesting restored from a flattened copy of this file — confirm
# which components originally sat inside gr.Column().
with gr.Blocks() as demo:
    # Header text; `title`, `description` and `joinus` come from `globe`.
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(joinus)
    with gr.Column():
        # type="filepath": the click handler receives the upload as a path.
        image_input = gr.Image(type="filepath", label="Input Image")
        # These labels must match the strings compared in update_inputs and
        # process_image.
        task_dropdown = gr.Dropdown(
            choices=[
                "Plain Text OCR",
                "Format Text OCR",
                "Fine-grained OCR (Box)",
                "Fine-grained OCR (Color)",
                "Multi-crop OCR",
                "Render Formatted OCR"
            ],
            label="Select Task",
            value="Plain Text OCR"
        )
        # The next three inputs start hidden; update_inputs reveals them for
        # the fine-grained tasks.
        ocr_type_dropdown = gr.Dropdown(
            choices=["ocr", "format"],
            label="OCR Type",
            visible=False
        )
        # Free-text box coordinates; process_image forwards this string to
        # the model without parsing it.
        ocr_box_input = gr.Textbox(
            label="OCR Box (x1,y1,x2,y2)",
            placeholder="[100,100,200,200]",
            visible=False
        )
        ocr_color_dropdown = gr.Dropdown(
            choices=["red", "green", "blue"],
            label="OCR Color",
            visible=False
        )
        submit_button = gr.Button("Process")
        # OCR text (wrapped in $$ ... $$ by ocr_demo) and the rendered-HTML iframe.
        output_markdown = gr.Markdown(label="🫴🏻📸GOT-OCR")
        output_html = gr.HTML(label="🫴🏻📸GOT-OCR")
    gr.Markdown(modelinfor)
    # Show/hide the task-specific inputs whenever the selected task changes;
    # output order matches update_inputs' three-element list.
    task_dropdown.change(
        update_inputs,
        inputs=[task_dropdown],
        outputs=[ocr_type_dropdown, ocr_box_input, ocr_color_dropdown]
    )
    # Input order matches ocr_demo's parameters; outputs match its
    # (markdown, html) return pair.
    submit_button.click(
        ocr_demo,
        inputs=[image_input, task_dropdown, ocr_type_dropdown, ocr_box_input, ocr_color_dropdown],
        outputs=[output_markdown, output_html]
    )
def _main():
    """Entry point: purge stale rendered HTML, then start the Gradio server."""
    cleanup_old_files()
    demo.launch()


if __name__ == "__main__":
    _main()