"""Gradio demo around a Donut checkpoint (`question_generator_by_en_on_pic`).

The script loads the model once at import time, exposes
`demo_process_vqa` through a `gr.Interface` taking an image and a text
field, and launches the web UI.
"""
import sys
import os
import pandas as pd
import numpy as np
import shutil
from tqdm import tqdm
import re

from donut import DonutModel
import torch
from PIL import Image
import gradio as gr

# Path (or identifier) handed to `DonutModel.from_pretrained`.
# NOTE: an analogous Chinese checkpoint ("question_generator_by_zh_on_pic")
# was previously wired in next to this one and is currently disabled.
en_model_path = "question_generator_by_en_on_pic"

# Prompt template; "{user_input}" is substituted with the text the user typed.
task_prompt = "{user_input}"

en_pretrained_model = DonutModel.from_pretrained(
    en_model_path, ignore_mismatched_sizes=True
)

if torch.cuda.is_available():
    # GPU path: run the whole model in fp16 on the CUDA device.
    en_pretrained_model.half()
    device = torch.device("cuda")
    en_pretrained_model.to(device)
else:
    # CPU path: only the encoder is cast to bfloat16.
    en_pretrained_model.encoder.to(torch.bfloat16)

en_pretrained_model.eval()
print("have load !")


def demo_process_vqa(input_img, question):
    """Run the model on one image plus a piece of user text.

    Args:
        input_img: Image data accepted by `PIL.Image.fromarray`
            (presumably the numpy array Gradio passes for an "image"
            input -- confirm against the Gradio version in use).
        question: Text inserted into `task_prompt` in place of
            "{user_input}". `None` is treated as an empty string.

    Returns:
        A dict with the keys "question" and "answer". The model's
        "answer" field is returned under "question" and its "question"
        field under "answer". A field the model did not produce is
        returned as `None`.

    Raises:
        ValueError: If no image was supplied.
    """
    if input_img is None:
        raise ValueError("An input image is required.")

    image = Image.fromarray(input_img)
    user_prompt = task_prompt.replace("{user_input}", question or "")
    output = en_pretrained_model.inference(image, prompt=user_prompt)["predictions"][0]

    # NOTE(review): the two fields are swapped on purpose in the original code;
    # presumably the checkpoint was trained with the roles of question and
    # answer exchanged (it is named a question generator) -- confirm.
    # `.get` keeps the UI responsive when the prediction lacks either field
    # instead of failing with a KeyError.
    return {
        "question": output.get("answer"),
        "answer": output.get("question"),
    }


# Example rows offered in the UI: [image file, text input].
example_sample = [["en_img.png", "605-7227"]]

demo = gr.Interface(
    fn=demo_process_vqa,
    inputs=["image", "text"],
    outputs=["json"],
    examples=example_sample if example_sample else None,
    # NOTE(review): "drive" is probably meant to read "derived"; the text is
    # left untouched because it is shown to users at runtime.
    description=(
        "This _example_ was **drive** from\n"
        "\n"
        "[https://github.com/svjack/docvqa-gen](https://github.com/svjack/docvqa-gen)\n"
        "\n"
        "\n"
    ),
    cache_examples=False,
)
demo.launch(share=False)