# Gradio demo app (Hugging Face Space) for UniMERNet formula recognition.
import argparse
import base64
import os
import re
import shutil
import sys

# Runs ahead of the third-party imports below, as in the original ordering.
os.system('pip install -U transformers==4.44.2')

import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import snapshot_download
from PIL import Image
# == download weights ==
# Fetch the three UniMERNet checkpoints into ./models/ at import time.
tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')

# Diagnostic listing of the downloaded files (output goes to the process log).
os.system("ls -l models/unimernet_tiny")
os.system("ls -l models/unimernet_small")
os.system("ls -l models/unimernet_base")
# == download weights ==

# Put the parent of the working directory on sys.path before importing the
# unimernet package below.
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor
# Standalone HTML page used to preview a recognized formula. latex2html()
# splices a JavaScript string literal right after the "const text =" marker;
# the page then loads mathpix-markdown-it from a CDN and renders that text
# into the #content-text element.
template_html = """<!DOCTYPE html>
<html lang="en" data-lt-installed="true"><head>
<meta charset="UTF-8">
<title>Title</title>
<script>
const text =
</script>
<style>
#content {
max-width: 800px;
margin: auto;
}
</style>
<script>
let script = document.createElement('script');
script.src = "https://cdn.jsdelivr.net/npm/mathpix-markdown-it@1.3.6/es5/bundle.js";
document.head.append(script);
script.onload = function() {
const isLoaded = window.loadMathJax();
if (isLoaded) {
console.log('Styles loaded!')
}
const el = window.document.getElementById('content-text');
if (el) {
const options = {
htmlTags: true
};
const html = window.render(text, options);
el.outerHTML = html;
}
};
</script>
</head>
<body>
<div id="content"><div id="content-text"></div></div>
</body>
</html>
"""

# Marker in template_html after which the JavaScript string literal is inserted.
_TEXT_MARKER = "const text ="

# Match the sizing commands \left / \right only. The negative lookahead keeps
# longer control words such as \leftarrow, \leftrightarrow or \rightarrow from
# being counted (or stripped) as delimiters.
_LEFT_RE = re.compile(r"\\left(?![A-Za-z])")
_RIGHT_RE = re.compile(r"\\right(?![A-Za-z])")


def latex2html(latex_code):
    r"""Embed LaTeX source into the preview page and return the full HTML.

    Args:
        latex_code: LaTeX/markdown text to render. May span several lines.

    Returns:
        A copy of ``template_html`` in which ``const text =`` is followed by a
        JavaScript string expression holding ``latex_code`` (one quoted piece
        per input line, joined with ``+``, each ending in an escaped newline).

    Before embedding, the text is normalized:

    * If the numbers of ``\left`` and ``\right`` commands differ, every
      ``\left``/``\right`` is removed and the bare delimiter is kept, since an
      unpaired sizing command cannot be rendered.
    * ``"`` becomes two backticks and ``$`` is dropped.
    """
    left_num = len(_LEFT_RE.findall(latex_code))
    right_num = len(_RIGHT_RE.findall(latex_code))
    if left_num != right_num:
        latex_code = _LEFT_RE.sub("", latex_code)
        latex_code = _RIGHT_RE.sub("", latex_code)

    latex_code = latex_code.replace('"', "``").replace("$", "")

    js_pieces = []
    for line in latex_code.split("\n"):
        # Double the backslashes for the JS literal, then break up "</" so the
        # text can never close the surrounding <script> element early
        # ("<\/" is still read as "</" by JavaScript).
        escaped = line.replace("\\", "\\\\").replace("</", "<\\/")
        js_pieces.append('"' + escaped + r"\n" + '"')
    js_literal = "+\n".join(js_pieces)

    head, _, tail = template_html.partition(_TEXT_MARKER)
    return head + _TEXT_MARKER + js_literal + tail
def load_model_and_processor(cfg_path):
    """Build a model and its evaluation image processor from a config file.

    Args:
        cfg_path: Path of the config file handed to ``Config``.

    Returns:
        A ``(model, vis_processor)`` tuple: the model built by the configured
        task, and the ``formula_image_eval`` processor created from
        ``datasets.formula_rec_eval.vis_processor.eval`` in that config.
    """
    # Config takes an argparse-style namespace; no option overrides are given.
    args = argparse.Namespace(cfg_path=cfg_path, options=None)
    cfg = Config(args)
    task = tasks.setup_task(cfg)
    model = task.build_model(cfg)
    vis_processor = load_processor(
        'formula_image_eval',
        cfg.config.datasets.formula_rec_eval.vis_processor.eval,
    )
    return model, vis_processor
def recognize_image(input_img, model_type):
    """Recognize the formula in an image and build an HTML preview of it.

    Args:
        input_img: Image array supplied by the ``gr.Image`` input, or ``None``
            when nothing has been uploaded.
        model_type: ``"base"`` or ``"small"``; any other value selects the
            tiny model.

    Returns:
        A ``(latex_code, iframe)`` tuple: the predicted LaTeX string and an
        ``<iframe>`` tag whose ``src`` is a base64 ``data:`` URL holding the
        page produced by ``latex2html``.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_img is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image before clicking Recognize.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_type == "base":
        selected = model_base
    elif model_type == "small":
        selected = model_small
    else:
        selected = model_tiny
    model = selected.to(device)

    if input_img.ndim == 3:
        # Reverse the channel axis before handing the array to PIL.
        # NOTE(review): presumably RGB -> BGR to match what the processor was
        # trained on — confirm against the unimernet processor.
        input_img = input_img[:, :, ::-1].copy()

    img = Image.fromarray(input_img)
    image = vis_processor(img).unsqueeze(0).to(device)
    # Inference only: do not record autograd state.
    with torch.no_grad():
        output = model.generate({"image": image})
    latex_code = output["pred_str"][0]

    html_code = latex2html(latex_code)
    encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
    iframe_src = f"data:text/html;base64,{encoded_html}"
    iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
    return latex_code, iframe
def gradio_reset():
    """Return three ``gr.update(value=None)`` objects, one per cleared output.

    Wired to the Clear button, whose outputs are the input image, the LaTeX
    textbox and the HTML preview.
    """
    return gr.update(value=None), gr.update(value=None), gr.update(value=None)
if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())

    # == load model ==
    # NOTE: every call rebinds vis_processor, so the processor built from
    # cfg_base.yaml is the one recognize_image() uses for all three models.
    print("load tiny model ...")
    model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    print("load small model ...")
    model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    print("load base model ...")
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    print("== load all models done. ==")
    # == load model ==

    # Explicit encoding so the header renders the same regardless of locale.
    with open("header.html", "r", encoding="utf-8") as file:
        header = file.read()

    with gr.Blocks() as demo:
        gr.HTML(header)

        with gr.Row():
            # Left column: model choice, image input, actions and examples.
            with gr.Column():
                model_type = gr.Radio(
                    choices=["tiny", "small", "base"],
                    value="tiny",
                    label="Model Type",
                    interactive=True,
                )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")

                with gr.Accordion("Examples:"):
                    example_root = os.path.join(os.path.dirname(__file__), "examples")
                    # Sorted so the example order does not depend on the
                    # filesystem's directory listing order.
                    gr.Examples(
                        examples=[
                            os.path.join(example_root, fname)
                            for fname in sorted(os.listdir(example_root))
                            if fname.endswith("png")
                        ],
                        inputs=input_img,
                    )

            # Right column: recognition results.
            with gr.Column():
                gr.Button(value="Predict Result:", interactive=False)
                pred_latex = gr.Textbox(label='Predict Latex', interactive=False)
                output_html = gr.HTML(label='Output Html')

        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex, output_html])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex, output_html])

    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)