# Spaces: Running on Zero
import gradio as gr | |
import torch | |
from transformers import AutoModel, AutoTokenizer, AutoConfig | |
import os | |
import base64 | |
import spaces | |
import io | |
from PIL import Image | |
import numpy as np | |
import yaml | |
from pathlib import Path | |
from globe import title, description, modelinfor, joinus | |
import uuid | |
import tempfile | |
import time | |
# Hugging Face Hub repository id shared by the tokenizer, config and model.
model_name = 'ucaslcl/GOT-OCR2_0'

# NOTE: trust_remote_code=True executes Python code shipped in the Hub
# repository, so the repository itself must be trusted.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# NOTE(review): `config` is not used by any code visible in this file.
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map='cuda',
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
)
# Switch to inference mode and make sure the weights live on the GPU.
model = model.eval().cuda()
# Reuse the end-of-sequence token as the padding token.
model.config.pad_token_id = tokenizer.eos_token_id
def image_to_base64(image):
    """Return ``image`` serialised as PNG and encoded as a base64 string.

    ``image`` only needs a ``save(file_obj, format=...)`` method
    (presumably a ``PIL.Image.Image`` -- confirm against callers).
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
# Directory that receives the rendered HTML files; it is created at import
# time (including missing parents) and an existing directory is accepted.
results_folder = Path('./results')
results_folder.mkdir(exist_ok=True, parents=True)
def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None):
    """Run the selected OCR task on ``image`` and collect its outputs.

    Args:
        image: Passed unchanged to ``model.chat`` / ``model.chat_crop``
            (the UI's image input is configured with ``type="filepath"``).
        task: One of the task labels listed in ``supported_tasks`` below.
        ocr_type: OCR mode forwarded for the two fine-grained tasks only.
        ocr_box: Box specification forwarded for "Fine-grained OCR (Box)".
        ocr_color: Colour name forwarded for "Fine-grained OCR (Color)".

    Returns:
        A ``(res, html_content, unique_id)`` tuple. ``res`` is whatever the
        model call returned, ``html_content`` is the text of the rendered
        HTML file (``None`` for "Plain Text OCR" or when no file was
        produced) and ``unique_id`` is the UUID string used to name that
        file.

    Raises:
        ValueError: If ``task`` is not a supported label. Previously this
            fell through every branch and failed on an unbound ``res``.
    """
    # NOTE(review): `spaces` is imported at file level but no `@spaces.GPU`
    # decorator is visible on this GPU-bound function -- confirm whether the
    # hosting environment requires one.
    supported_tasks = (
        "Plain Text OCR",
        "Format Text OCR",
        "Fine-grained OCR (Box)",
        "Fine-grained OCR (Color)",
        "Multi-crop OCR",
        "Render Formatted OCR",
    )
    # Validate before touching the model or the filesystem.
    if task not in supported_tasks:
        raise ValueError(f"Unsupported OCR task: {task!r}")

    unique_id = str(uuid.uuid4())

    if task == "Plain Text OCR":
        # Plain OCR produces no rendered HTML.
        res = model.chat(tokenizer, image, ocr_type='ocr')
        return res, None, unique_id

    # Every remaining task asks the model to render into a per-request file.
    temp_html_path = results_folder / f"{unique_id}.html"
    render_kwargs = {"render": True, "save_render_file": str(temp_html_path)}

    if task == "Fine-grained OCR (Box)":
        res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_box=ocr_box, **render_kwargs)
    elif task == "Fine-grained OCR (Color)":
        res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_color=ocr_color, **render_kwargs)
    elif task == "Multi-crop OCR":
        res = model.chat_crop(tokenizer, image, ocr_type='format', **render_kwargs)
    else:
        # "Format Text OCR" and "Render Formatted OCR" issue the same call.
        res = model.chat(tokenizer, image, ocr_type='format', **render_kwargs)

    try:
        # Encoding is left at the platform default, as before; presumably
        # that matches how the model's renderer wrote the file -- confirm.
        html_content = temp_html_path.read_text()
    except FileNotFoundError:
        # The model did not produce a render for this request.
        html_content = None
    return res, html_content, unique_id
def update_inputs(task):
    """Return visibility updates for the three task-specific inputs.

    The returned list has one ``gr.update`` per control, in the order
    OCR type, OCR box, OCR colour. Any task label not handled below yields
    ``None``.
    """
    if task == "Fine-grained OCR (Box)":
        # Box mode: show the OCR type and the box field, hide the colour.
        type_update = gr.update(visible=True, choices=["ocr", "format"])
        box_update = gr.update(visible=True)
        color_update = gr.update(visible=False)
        return [type_update, box_update, color_update]

    if task == "Fine-grained OCR (Color)":
        # Colour mode: show the OCR type and the colour, hide the box field.
        type_update = gr.update(visible=True, choices=["ocr", "format"])
        box_update = gr.update(visible=False)
        color_update = gr.update(visible=True, choices=["red", "green", "blue"])
        return [type_update, box_update, color_update]

    tasks_without_options = [
        "Plain Text OCR",
        "Format Text OCR",
        "Multi-crop OCR",
        "Render Formatted OCR",
    ]
    if task in tasks_without_options:
        # None of the extra controls apply: hide all three.
        return [gr.update(visible=False)] * 3

    return None
def ocr_demo(image, task, ocr_type, ocr_box, ocr_color):
    """Gradio click handler: run OCR and build the two UI outputs.

    Args:
        image, task, ocr_type, ocr_box, ocr_color: Forwarded unchanged to
            ``process_image``.

    Returns:
        A ``(markdown_text, iframe_html)`` pair. ``markdown_text`` is the
        OCR result wrapped in ``$$ ... $$``; ``iframe_html`` is an
        ``<iframe>`` embedding the rendered HTML, or ``None`` when nothing
        was rendered.
    """
    import html  # stdlib; imported locally because the file does not import it

    # process_image returns three values (text, rendered HTML, request id);
    # unpacking only two raised ValueError on every call. The id is not
    # displayed in the UI.
    res, html_content, _unique_id = process_image(image, task, ocr_type, ocr_box, ocr_color)
    res = f"$$ {res} $$"
    if html_content:
        # The document is embedded inside a double-quoted attribute, so
        # quotes and ampersands must be escaped or the first `"` in the
        # HTML would terminate srcdoc early.
        escaped_html = html.escape(html_content, quote=True)
        html_string = f'<iframe srcdoc="{escaped_html}" width="100%" height="600px"></iframe>'
        return res, html_string
    return res, None
def cleanup_old_files(max_age_seconds=3600, folder=None):
    """Delete ``*.html`` files whose modification time is too old.

    Args:
        max_age_seconds: Files modified more than this many seconds ago are
            removed. Defaults to one hour.
        folder: Directory to scan (non-recursively). Defaults to the
            module-level ``results_folder``.

    Returns:
        None.
    """
    target_dir = results_folder if folder is None else Path(folder)
    current_time = time.time()
    for file_path in target_dir.glob('*.html'):
        try:
            if current_time - file_path.stat().st_mtime > max_age_seconds:
                file_path.unlink()
        except FileNotFoundError:
            # The file disappeared between listing and stat/unlink;
            # there is nothing left to clean up for it.
            continue
# Gradio UI: header text, the input controls, the two output panes and the
# event wiring between them.
# NOTE(review): the original indentation was lost; the extent of the
# gr.Column() block below is reconstructed and should be confirmed.
with gr.Blocks() as demo:
    # Header copy imported from `globe`.
    for header_text in (title, description, joinus):
        gr.Markdown(header_text)

    with gr.Column():
        image_input = gr.Image(type="filepath", label="Input Image")

        task_dropdown = gr.Dropdown(
            label="Select Task",
            value="Plain Text OCR",
            choices=[
                "Plain Text OCR",
                "Format Text OCR",
                "Fine-grained OCR (Box)",
                "Fine-grained OCR (Color)",
                "Multi-crop OCR",
                "Render Formatted OCR",
            ],
        )

        # The three task-specific controls start hidden; update_inputs
        # toggles them when the selected task changes.
        ocr_type_dropdown = gr.Dropdown(
            label="OCR Type",
            choices=["ocr", "format"],
            visible=False,
        )
        ocr_box_input = gr.Textbox(
            label="OCR Box (x1,y1,x2,y2)",
            placeholder="[100,100,200,200]",
            visible=False,
        )
        ocr_color_dropdown = gr.Dropdown(
            label="OCR Color",
            choices=["red", "green", "blue"],
            visible=False,
        )

        submit_button = gr.Button("Process")

        output_markdown = gr.Markdown(label="🫴🏻📸GOT-OCR")
        output_html = gr.HTML(label="🫴🏻📸GOT-OCR")

    gr.Markdown(modelinfor)

    # Show or hide the task-specific controls whenever the task changes.
    task_dropdown.change(
        update_inputs,
        inputs=[task_dropdown],
        outputs=[ocr_type_dropdown, ocr_box_input, ocr_color_dropdown],
    )

    # Run OCR on click and route the results to the two output panes.
    submit_button.click(
        ocr_demo,
        inputs=[
            image_input,
            task_dropdown,
            ocr_type_dropdown,
            ocr_box_input,
            ocr_color_dropdown,
        ],
        outputs=[output_markdown, output_html],
    )
# Script entry point: purge stale rendered HTML files, then start the app.
if __name__ == "__main__":
    cleanup_old_files()
    demo.launch()