Spaces:
Runtime error
Runtime error
| import argparse | |
| import json | |
| import os | |
| import torch | |
| from PIL import Image | |
| from qwen_vl_utils import process_vision_info | |
| from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration | |
| import gradio as gr | |
# Instruction sent alongside every image: asks the model for the LaTeX of the
# formulas only, ignoring inline formulas, other text, and explanations.
user_prompt = (
    "Analyze the image. "
    "Extract and output only the LaTeX formulas present in the image, in LaTeX code format. "
    "Ignore inline formulas, all other text, and do not include any explanations."
)
def read_input_file(input_file):
    """Load an example JSON file and return its first record's image path and reference LaTeX.

    Args:
        input_file: Path to a JSON file containing a list of records. Only the
            first record is read; it must have an ``images`` list and a
            ``messages`` list with at least two entries.

    Returns:
        A ``(image_path, gt_latex_code)`` tuple: the first entry of
        ``images`` and the ``content`` of the second message.
    """
    # Decode explicitly as UTF-8: relying on the platform default encoding
    # (e.g. cp1252 on Windows) breaks on non-ASCII characters in the file.
    with open(input_file, "r", encoding="utf-8") as file:
        data = json.load(file)
    first_record = data[0]
    image_path = first_record["images"][0]
    gt_latex_code = first_record["messages"][1]["content"]
    return image_path, gt_latex_code
class ImageProcessor:
    """Wrap a Qwen2.5-VL checkpoint and its processor for single-image LaTeX extraction."""

    def __init__(self, args):
        """Load the model/processor named by ``args.ckpt`` and set generation options.

        Args:
            args: Parsed command-line namespace; ``args.ckpt`` is read when loading.
        """
        self.args = args
        self.model, self.vis_processor = self.load_model_and_processor()
        # top_k=1 keeps only the most likely token at each step, so decoding is
        # effectively greedy; the tiny top_p/temperature reinforce that.
        self.generate_kwargs = dict(
            max_new_tokens=2048,
            top_p=0.001,
            top_k=1,
            temperature=0.01,
            repetition_penalty=1.0,
        )

    def load_model_and_processor(self):
        """Load the processor and model from the checkpoint in ``self.args.ckpt``.

        Returns:
            A ``(model, vis_processor)`` tuple; the model is switched to eval mode.
        """
        checkpoint = self.args.ckpt
        vis_processor = AutoProcessor.from_pretrained(checkpoint)
        # torch_dtype="auto" uses the checkpoint's dtype; device_map="auto" lets
        # transformers decide where to place the weights.
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            checkpoint, torch_dtype="auto", device_map="auto"
        )
        model.eval()
        return model, vis_processor

    def process_single_image(self, image_path):
        """Run the model on one local image and return the generated text.

        Args:
            image_path: Filesystem path of the image to analyze.

        Returns:
            The decoded model output, or the string ``"None"`` if any step
            fails (the error and its full traceback are printed).
        """
        question = user_prompt
        try:
            # The image is referenced by a "file://" URI for process_vision_info.
            image_local_path = "file://" + image_path
            messages = [
                {
                    "role": "user",
                    "content": [
                        # min/max_pixels bound the image size used for the vision input.
                        {"type": "image", "image": image_local_path, "min_pixels": 32 * 32, "max_pixels": 512 * 512},
                        {"type": "text", "text": question},
                    ],
                }
            ]
            text = self.vis_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            images, videos = process_vision_info([messages])
            inputs = self.vis_processor(text=text, images=images, videos=videos, padding=True, return_tensors='pt')
            inputs = inputs.to(self.model.device)
            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    **self.generate_kwargs,
                )
            # generate() returns prompt + continuation; strip the prompt tokens.
            generated_ids = [
                output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
            ]
            out = self.vis_processor.tokenizer.batch_decode(
                generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            model_answer = out[0]
        except Exception as e:
            # Deliberate best-effort for the UI: report the failure but still
            # return a string. Print the full traceback, not just str(e), so the
            # cause (missing file, OOM, processor error, ...) can be diagnosed.
            import traceback

            print(e, flush=True)
            traceback.print_exc()
            model_answer = "None"
        return model_answer
def save_image_with_auto_naming(image, save_dir="./tmp"):
    """Save *image* as ``<n>.png`` in *save_dir* and return the path written.

    ``n`` is one more than the largest number among existing ``<digits>.png``
    files (0 when there are none). If that name is taken by the time the file
    is created, the next number is tried, so an existing file is never
    overwritten.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method.
        save_dir: Target directory; created if it does not exist.

    Returns:
        The path of the newly written PNG file.
    """
    # Ensure the directory exists.
    os.makedirs(save_dir, exist_ok=True)

    # Collect numbers from files named exactly "<digits>.png". isdecimal() is
    # used instead of isdigit(): isdigit() accepts characters such as "²" that
    # int() rejects, which would raise ValueError here.
    used_numbers = []
    for file_name in os.listdir(save_dir):
        stem, ext = os.path.splitext(file_name)
        if ext == ".png" and stem.isdecimal():
            used_numbers.append(int(stem))
    next_num = max(used_numbers, default=-1) + 1

    while True:
        temp_path = os.path.join(save_dir, f"{next_num}.png")
        try:
            # Mode "x" creates the file atomically and fails if it already
            # exists, so concurrent callers can never claim the same name.
            with open(temp_path, "xb") as fh:
                image.save(fh, format="PNG")
        except FileExistsError:
            next_num += 1
            continue
        return temp_path
| # {{ edit_1 }} | |
def process_image_for_gradio(image):
    """Gradio callback: recognize LaTeX formulas in the given image.

    Args:
        image: The PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when no image is present.

    Returns:
        The predicted LaTeX text, or ``""`` when *image* is ``None``.
    """
    if image is None:
        return ""
    # The model reads from disk, so write the image to an auto-named temp file.
    temp_path = save_image_with_auto_naming(image)
    try:
        # `processor` is the module-level ImageProcessor created under __main__.
        return processor.process_single_image(temp_path)
    finally:
        # Always clean up the temp file, even if inference raises; tolerate it
        # having already been removed.
        try:
            os.remove(temp_path)
        except FileNotFoundError:
            pass
def load_example(example_name):
    """Load one of the bundled example images.

    Args:
        example_name: Base name of a JSON file in ``./asset/test_jsons``
            (e.g. ``"line-level"``).

    Returns:
        ``(image, example_name)``: an in-memory PIL copy of the example image
        and the name that was passed in.
    """
    input_file = os.path.join('./asset/test_jsons', f"{example_name}.json")
    image_path, _gt_latex_code = read_input_file(input_file)
    # Copy inside the context manager: the copy holds the pixel data in memory,
    # and the file handle is closed right away instead of lingering until GC.
    with Image.open(image_path) as img:
        image = img.copy()
    return image, example_name
| # {{ edit_2 }} | |
def create_gradio_interface(processor):
    """Build the Gradio Blocks UI for LaTeX formula recognition.

    Args:
        processor: The ImageProcessor instance. It is not referenced here; the
            callbacks reach the model through the module-level ``processor``.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(title="DocTron-Formula") as demo:
        gr.Markdown("# DocTron-Formula LaTeX公式识别")
        gr.Markdown("上传图片或选择示例来识别LaTeX公式")

        with gr.Row():
            # Left column: image input, action buttons and example selectors.
            with gr.Column():
                image_input = gr.Image(type="pil", label="上传图片")
                with gr.Row():
                    clear_btn = gr.Button("Clear")
                    submit_btn = gr.Button("Submit", variant="primary")
                gr.Markdown("### 示例图片")
                with gr.Row():
                    line_btn = gr.Button("Line-level")
                    paragraph_btn = gr.Button("Paragraph-level")
                    page_btn = gr.Button("Page-level")
                # Receives the name of the example most recently loaded.
                example_name = gr.State()
            # Right column: the recognition result.
            with gr.Column():
                latex_output = gr.Textbox(label="预测的LaTeX公式", lines=10, interactive=False)

        # Submit runs recognition on the image currently in the input.
        submit_btn.click(
            fn=process_image_for_gradio,
            inputs=[image_input],
            outputs=[latex_output],
        )
        # Clear empties both the image input and the text output.
        clear_btn.click(
            fn=lambda: (None, ""),
            inputs=[],
            outputs=[image_input, latex_output],
        )

        # Each example button passes its example name (via a hidden Textbox) to
        # load_example, then runs recognition on the image that was loaded.
        example_bindings = (
            (line_btn, "line-level"),
            (paragraph_btn, "paragraph-level"),
            (page_btn, "page-level"),
        )
        for button, name in example_bindings:
            name_source = gr.Textbox(value=name, visible=False)
            button.click(
                fn=load_example,
                inputs=name_source,
                outputs=[image_input, example_name],
            ).then(
                fn=lambda img: process_image_for_gradio(img),
                inputs=[image_input],
                outputs=[latex_output],
            )
    return demo
if __name__ == "__main__":
    # Command-line options. --input_file is parsed but not read anywhere below.
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, default="DocTron/DocTron-Formula")
    parser.add_argument("--input_file", type=str, default="line-level")
    args = parser.parse_args()

    # Must be bound to the module-level name `processor`:
    # process_image_for_gradio looks the model wrapper up as a global.
    processor = ImageProcessor(args)

    # Build the UI and serve it with Gradio's default launch settings.
    demo = create_gradio_interface(processor)
    demo.launch()