Spaces:
Runtime error
Runtime error
File size: 4,357 Bytes
cc5f7b7 61d07b2 cc5f7b7 a47bf95 a671104 a47bf95 cc5f7b7 a47bf95 cc5f7b7 a47bf95 bb7b589 9eb5356 cad4ba9 c0d1798 cad4ba9 ec0123f 9eb5356 a47bf95 a671104 a47bf95 a671104 a47bf95 a671104 a47bf95 cc5f7b7 a68694f a671104 cc5f7b7 a47bf95 a671104 cc5f7b7 a671104 cc5f7b7 a47bf95 cc5f7b7 a47bf95 cc5f7b7 a47bf95 cc5f7b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import os
# NOTE(review): installing a pinned transformers at runtime via os.system is a
# Spaces workaround -- it mutates the environment on every boot; consider
# moving the pin into requirements.txt.
os.system('pip install -U transformers==4.44.2')
import sys
import shutil
import torch
import argparse
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
# == download weights ==
# Fetch the three UniMERNet checkpoints (tiny/small/base) into ./models/;
# snapshot_download is a no-op if the local copy is already up to date.
tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')
# Debug: list the downloaded tiny-model files in the build log.
os.system("ls -l models/unimernet_tiny")
# NOTE(review): dead config-patching / file-moving code kept for reference;
# the cfg_*.yaml files presumably already point at ./models -- confirm.
# os.system(f"sed -i 's/MODEL_DIR/{tiny_model_dir}/g' cfg_tiny.yaml")
# os.system(f"sed -i 's/MODEL_DIR/{small_model_dir}/g' cfg_small.yaml")
# os.system(f"sed -i 's/MODEL_DIR/{base_model_dir}/g' cfg_base.yaml")
# root_path = os.path.abspath(os.getcwd())
# os.makedirs(os.path.join(root_path, "models"), exist_ok=True)
# shutil.move(tiny_model_dir, os.path.join(root_path, "models", "unimernet_tiny"))
# shutil.move(small_model_dir, os.path.join(root_path, "models", "unimernet_small"))
# shutil.move(base_model_dir, os.path.join(root_path, "models", "unimernet_base"))
# == download weights ==
# Make the sibling unimernet package importable when running from a checkout.
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor
def load_model_and_processor(cfg_path):
    """Build a UniMERNet model and its eval-time image processor from a config file.

    Args:
        cfg_path: Path to a task YAML configuration (e.g. cfg_tiny.yaml).

    Returns:
        Tuple of (model, vis_processor) ready for inference.
    """
    config = Config(argparse.Namespace(cfg_path=cfg_path, options=None))
    recognition_task = tasks.setup_task(config)
    built_model = recognition_task.build_model(config)
    eval_proc_cfg = config.config.datasets.formula_rec_eval.vis_processor.eval
    image_processor = load_processor('formula_image_eval', eval_proc_cfg)
    return built_model, image_processor
def recognize_image(input_img, model_type):
    """Run formula recognition on an image and return the predicted LaTeX.

    Args:
        input_img: Image as a numpy array from the Gradio image widget
            (HxWxC uint8, or HxW for grayscale), or None when no image
            was submitted.
        model_type: One of "tiny", "small", "base"; selects a preloaded model.
            Any other value falls back to the tiny model.

    Returns:
        Predicted LaTeX code as a string; "" when input_img is None.
    """
    # Guard: the Clear button / an empty submit sends None, which previously
    # crashed on input_img.shape.
    if input_img is None:
        return ""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_type == "base":
        model = model_base.to(device)
    elif model_type == "small":
        model = model_small.to(device)
    else:
        model = model_tiny.to(device)
    if len(input_img.shape) == 3:
        # NOTE(review): reverses channel order (RGB -> BGR); presumably the
        # processor expects BGR input -- confirm against the training pipeline.
        input_img = input_img[:, :, ::-1].copy()
    img = Image.fromarray(input_img)
    image = vis_processor(img).unsqueeze(0).to(device)
    # Inference only: disable autograd bookkeeping to cut memory use.
    with torch.no_grad():
        output = model.generate({"image": image})
    return output["pred_str"][0]
def gradio_reset():
    """Clear-button handler: blank out the input image and the LaTeX textbox."""
    blank_image = gr.update(value=None)
    blank_latex = gr.update(value=None)
    return blank_image, blank_latex
if __name__ == "__main__":
root_path = os.path.abspath(os.getcwd())
# == load model ==
model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
print("== load all models ==")
# == load model ==
with open("header.html", "r") as file:
header = file.read()
with gr.Blocks() as demo:
gr.HTML(header)
with gr.Row():
with gr.Column():
model_type = gr.Radio(
choices=["tiny", "small", "base"],
value="tiny",
label="Model Type",
interactive=True,
)
input_img = gr.Image(label=" ", interactive=True)
with gr.Row():
clear = gr.Button("Clear")
predict = gr.Button(value="Recognize", interactive=True, variant="primary")
with gr.Accordion("Examples:"):
example_root = os.path.join(os.path.dirname(__file__), "examples")
gr.Examples(
examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
_.endswith("png")],
inputs=input_img,
)
with gr.Column():
gr.Button(value="Predict Latex:", interactive=False)
pred_latex = gr.Textbox(label='Latex', interactive=False)
clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex])
demo.launch(server_name="0.0.0.0", server_port=7860, debug=True) |