# Hugging Face Spaces status banner captured along with this source: "Spaces: Paused".
import gradio as gr | |
import torch | |
from transformers import AutoModel, AutoTokenizer, AutoConfig | |
import os | |
import base64 | |
import io | |
from PIL import Image | |
import numpy as np | |
import uuid | |
import cv2 | |
import re | |
from globe import title, description, modelinfor, joinus, howto | |
# GOT-OCR 2.0 checkpoint on the Hugging Face Hub. trust_remote_code=True lets
# transformers execute the repository's own Python code (the custom `chat`
# method used below comes from there).
model_name = 'ucaslcl/GOT-OCR2_0'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# NOTE(review): `config` is not used anywhere in the visible code.
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# Load the weights, switch to inference mode and move to the GPU in one chain.
# NOTE(review): device_map='cuda' already places the weights on the GPU, so the
# trailing .cuda() is most likely a no-op.
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map='cuda',
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
).eval().cuda()
# Also pin the pad token on the loaded config so generation does not warn about it.
model.config.pad_token_id = tokenizer.eos_token_id
# Per-request scratch directories: uploaded images are written to
# UPLOAD_FOLDER and the model's rendered HTML to RESULTS_FOLDER.
UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"
# exist_ok=True replaces the old exists()-then-makedirs() check. That check
# could raise FileExistsError if another worker created the directory in
# between.
for folder in (UPLOAD_FOLDER, RESULTS_FOLDER):
    os.makedirs(folder, exist_ok=True)
def image_to_base64(image):
    """Encode an image as a base64 PNG string.

    Args:
        image: Any object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The PNG bytes, base64-encoded and decoded to ``str``.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
def _save_editor_image(image, image_path):
    """Write the ImageEditor composite to ``image_path``; return an error string or None."""
    if not isinstance(image, dict):
        return "Error: Unsupported image format"
    composite_image = image.get("composite")
    if composite_image is None:
        return "Error: No composite image found in ImageEditor output"
    if isinstance(composite_image, np.ndarray):
        # The array is treated as RGB (hence RGB2BGR); cv2.imwrite expects BGR.
        cv2.imwrite(image_path, cv2.cvtColor(composite_image, cv2.COLOR_RGB2BGR))
    elif isinstance(composite_image, Image.Image):
        composite_image.save(image_path)
    else:
        return "Error: Unsupported image format from ImageEditor"
    return None


def process_image(image, ocr_type, ocr_box=None, ocr_color=None):
    """Run GOT-OCR on the composite image coming from a Gradio ImageEditor.

    Args:
        image: ImageEditor value. It is expected to be a dict whose
            ``"composite"`` entry is a numpy array or PIL image.
        ocr_type: Passed through to ``model.chat`` (the UI offers "ocr" / "format").
        ocr_box: Optional box passed to ``model.chat``. It is only used when
            ``ocr_color`` is falsy.
        ocr_color: Optional colour passed to ``model.chat``. When given, it
            takes precedence over ``ocr_box``.

    Returns:
        ``(res, html_content)``:
        - ``res`` is the model output.
        - ``html_content`` is the rendered HTML, or None if no render file was produced.
        On any failure, returns ``("Error: ...", None)``.
    """
    unique_id = str(uuid.uuid4())
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
    result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")
    try:
        error = _save_editor_image(image, image_path)
        if error is not None:
            return error, None
        if ocr_color:
            res = model.chat(tokenizer, image_path, ocr_type=ocr_type, ocr_color=ocr_color,
                             render=True, save_render_file=result_path)
        else:
            res = model.chat(tokenizer, image_path, ocr_type=ocr_type, ocr_box=ocr_box,
                             render=True, save_render_file=result_path)
        if os.path.exists(result_path):
            # Explicit UTF-8 rather than the platform default encoding. The text
            # is re-encoded as UTF-8 downstream.
            with open(result_path, "r", encoding="utf-8") as f:
                return res, f.read()
        return res, None
    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing the app.
        return f"Error: {e}", None
    finally:
        # Both files are per-request scratch files. Previously the rendered HTML
        # was never deleted, so RESULTS_FOLDER grew by one file per request.
        for path in (image_path, result_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
# Characters whose presence marks a segment as LaTeX. The previous regex list
# used r'^', which is a start-of-string anchor, not a literal caret. It matched
# every line, so every segment was wrapped in $$ ... $$.
_LATEX_MARKERS = frozenset('{}[]\\$_^"')


def parse_latex_output(res):
    """Wrap runs of LaTeX-looking segments of OCR output in ``$$`` delimiters.

    ``res`` is split on existing ``$$...$$`` spans (kept as their own segments).
    Each segment is processed as follows:
    - A bare ``"\\n"`` segment is kept verbatim.
    - Any other segment is stripped.
    - Consecutive segments containing a LaTeX marker character are grouped
      between a leading and trailing ``"$$"`` element.
    - Other segments pass through unchanged.

    Args:
        res: Raw text produced by the OCR model.

    Returns:
        The emitted elements joined with ``'$$\\\\$$\\n'``.

    NOTE(review): the join separator and the re-wrapping of segments that are
    already ``$$``-delimited are kept from the original. Confirm the intended
    Markdown rendering.
    """
    segments = re.split(r'(\$\$.*?\$\$)', res, flags=re.DOTALL)
    parsed_lines = []
    in_latex = False
    latex_buffer = []
    for segment in segments:
        if segment == '\n':
            # Bare newlines are kept verbatim, inside or outside a LaTeX run.
            if in_latex:
                latex_buffer.append(segment)
            else:
                parsed_lines.append(segment)
            continue
        line = segment.strip()
        if any(ch in _LATEX_MARKERS for ch in line):
            if not in_latex:
                in_latex = True
                latex_buffer = ['$$']
            latex_buffer.append(line)
        else:
            if in_latex:
                # A plain segment closes the current LaTeX run.
                latex_buffer.append('$$')
                parsed_lines.extend(latex_buffer)
                in_latex = False
                latex_buffer = []
            parsed_lines.append(line)
    if in_latex:
        latex_buffer.append('$$')
        parsed_lines.extend(latex_buffer)
    return '$$\\$$\n'.join(parsed_lines)
def ocr_demo(image, ocr_type, ocr_color):
    """Gradio callback: run OCR and build the Markdown and HTML outputs.

    Args:
        image: ImageEditor value, forwarded to ``process_image``.
        ocr_type: OCR mode, forwarded to ``process_image``.
        ocr_color: Box colour, forwarded to ``process_image``.

    Returns:
        ``(markdown_text, html_or_None)``. When ``process_image`` reports an
        error string, that string is returned with None. When rendered HTML is
        available, the second element is an inline iframe plus a download link.
        Both use a base64 ``data:`` URI.
    """
    res, html_content = process_image(image, ocr_type, ocr_color=ocr_color)
    if isinstance(res, str) and res.startswith("Error:"):
        return res, None

    # Add a space after \title before running the LaTeX wrapper.
    markdown_text = parse_latex_output(res.replace("\\title", "\\title "))
    if not html_content:
        return markdown_text, None

    payload = base64.b64encode(html_content.encode('utf-8')).decode('utf-8')
    data_uri = f"data:text/html;base64,{payload}"
    viewer = f'<iframe src="{data_uri}" width="100%" height="600px"></iframe>'
    download = f'<a href="{data_uri}" download="result_{uuid.uuid4()}.html">Download Full Result</a>'
    return markdown_text, f"{viewer}<br>{download}"
with gr.Blocks(theme=gr.themes.Base()) as demo:
    # Header row: page title.
    with gr.Row():
        gr.Markdown(title)

    # Intro: description on the left, model info and community links on the right.
    with gr.Row():
        with gr.Column(scale=1), gr.Group():
            gr.Markdown(description)
        with gr.Column(scale=1), gr.Group():
            gr.Markdown(modelinfor)
            gr.Markdown(joinus)

    # Usage walkthrough: four screenshots, then the textual instructions.
    with gr.Row():
        with gr.Accordion("How to use 🫴🏻👁GOT OCR", open=True):
            with gr.Row():
                for step_image, step_label in (
                    ("res/image/howto_1.png", "Select the Following Parameters"),
                    ("res/image/howto_2.png", "Click on Paintbrush in the Image Editor"),
                    ("res/image/howto_3.png", "Select your Brush Color (Red)"),
                    ("res/image/howto_4.png", "Make a Box Around The Text"),
                ):
                    gr.Image(step_image, label=step_label)
            with gr.Row(), gr.Group():
                gr.Markdown(howto)

    # Workspace: inputs on the left, OCR results on the right.
    with gr.Row():
        with gr.Column(scale=1):
            image_editor = gr.ImageEditor(label="Image Editor", type="pil", height=800)
            ocr_type_dropdown = gr.Dropdown(choices=["ocr", "format"], label="OCR Type", value="ocr")
            ocr_color_dropdown = gr.Dropdown(choices=["red", "green", "blue"], label="OCR Color", value="red")
            submit_button = gr.Button("Process")
        with gr.Column(scale=1):
            output_markdown = gr.Markdown(label="🫴🏻👁GOT-OCR")
            output_html = gr.HTML(label="🫴🏻👁GOT-OCR")

    submit_button.click(
        fn=ocr_demo,
        inputs=[image_editor, ocr_type_dropdown, ocr_color_dropdown],
        outputs=[output_markdown, output_html],
    )

if __name__ == "__main__":
    demo.launch()