Spaces: Running on Zero
import os
os.system('pip install -U transformers==4.44.2')
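# Pin the transformers version up front; everything imported below (unimernet and its
# processors) is presumably sensitive to the transformers API it was written against.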
import sys
import shutil
import torch
import base64
import argparse
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
import spaces
# == download weights ==
tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')
os.system("ls -l models/unimernet_tiny")
os.system("ls -l models/unimernet_small")
os.system("ls -l models/unimernet_base")
# == download weights ==
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor
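# HTML template used to render the predicted LaTeX in the browser: latex2html() below
# splices the prediction in at the "const text =" placeholder, and the mathpix-markdown-it
# bundle loaded from the CDN performs the actual rendering.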
template_html = """<!DOCTYPE html>
<html lang="en" data-lt-installed="true">
<head>
    <meta charset="UTF-8">
    <title>Title</title>
    <script>
        const text =
    </script>
    <style>
        #content {
            max-width: 800px;
            margin: auto;
        }
    </style>
    <script>
        let script = document.createElement('script');
        script.src = "https://cdn.jsdelivr.net/npm/mathpix-markdown-it@1.3.6/es5/bundle.js";
        document.head.append(script);
        script.onload = function() {
            const isLoaded = window.loadMathJax();
            if (isLoaded) {
                console.log('Styles loaded!')
            }
            const el = window.document.getElementById('content-text');
            if (el) {
                const options = {
                    htmlTags: true
                };
                const html = window.render(text, options);
                el.outerHTML = html;
            }
        };
    </script>
</head>
<body>
<div id="content"><div id="content-text"></div></div>
</body>
</html>
"""
def latex2html(latex_code):
    latex_code = '\\[' + latex_code + '\\]'
    # Drop \left/\right sizing commands and characters that break the JS renderer.
    latex_code = (latex_code
                  .replace('\\left(', '(').replace('\\right)', ')')
                  .replace('\\left[', '[').replace('\\right]', ']')
                  .replace('\\left{', '{').replace('\\right}', '}')
                  .replace('\\left|', '|').replace('\\right|', '|')
                  .replace('\\left.', '.').replace('\\right.', '.'))
    latex_code = latex_code.replace('"', '``').replace('$', '')
    # Re-emit each line as a quoted JavaScript string, joined with '+'.
    latex_code_list = latex_code.split('\n')
    gt = ''
    for out in latex_code_list:
        gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
    gt = gt[:-2]
    lines = template_html.split("const text =")
    new_web = lines[0] + 'const text =' + gt + lines[1]
    return new_web
def load_model_and_processor(cfg_path):
    args = argparse.Namespace(cfg_path=cfg_path, options=None)
    cfg = Config(args)
    task = tasks.setup_task(cfg)
    model = task.build_model(cfg)
    vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
    return model, vis_processor
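# @spaces.GPU asks Hugging Face ZeroGPU for a GPU only for the duration of this call;
# the models are built on CPU at startup and moved to the device per request.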
@spaces.GPU
def recognize_image(input_img, model_type):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_type == "base":
        model = model_base.to(device)
    elif model_type == "small":
        model = model_small.to(device)
    else:
        model = model_tiny.to(device)
    model.eval()
    if len(input_img.shape) == 3:
        input_img = input_img[:, :, ::-1].copy()
    img = Image.fromarray(input_img)
    image = vis_processor(img).unsqueeze(0).to(device)
    output = model.generate({"image": image})
    latex_code = output["pred_str"][0]
    html_code = latex2html(latex_code)
    encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
    iframe_src = f"data:text/html;base64,{encoded_html}"
    iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
    return latex_code, iframe
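# Clear the image input and both outputs when the "Clear" button is pressed.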
def gradio_reset():
    return gr.update(value=None), gr.update(value=None), gr.update(value=None)
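# Entry point: load the three model variants from their configs, then assemble the Gradio UI
# (model selector, image input, example gallery, and LaTeX / rendered-HTML outputs).
# Note that vis_processor is overwritten on each load; the last one is shared by all models.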
if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())
    # == load model ==
    print("load tiny model ...")
    model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    print("load small model ...")
    model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    print("load base model ...")
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    print("== load all models done. ==")
    # == load model ==

    with open("header.html", "r") as file:
        header = file.read()

    with gr.Blocks() as demo:
        gr.HTML(header)
        with gr.Row():
            with gr.Column():
                model_type = gr.Radio(
                    choices=["tiny", "small", "base"],
                    value="tiny",
                    label="Model Type",
                    interactive=True,
                )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")

                with gr.Accordion("Examples:"):
                    example_root = os.path.join(os.path.dirname(__file__), "examples")
                    gr.Examples(
                        examples=[os.path.join(example_root, _) for _ in os.listdir(example_root)
                                  if _.endswith("png")],
                        inputs=input_img,
                    )
            with gr.Column():
                gr.Button(value="Prediction Result:", interactive=False)
                pred_latex = gr.Textbox(label='Predicted LaTeX', interactive=False)
                output_html = gr.HTML(label="Rendered HTML", show_label=True)

        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex, output_html])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex, output_html])

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)