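"""Gradio demo for svjack/docvqa-gen: given a document image and a piece of
text, a fine-tuned Donut model generates a matching question/answer pair."""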
from donut import DonutModel
import torch
from PIL import Image
import gradio as gr
#from train import *
#en_model_path = "question_generator_by_en_on_pic"
zh_model_path = "question_generator_by_zh_on_pic"
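# Donut DocVQA-style task prompt; "{user_input}" is replaced per request.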
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
#en_pretrained_model = DonutModel.from_pretrained(en_model_path)
#zh_pretrained_model = DonutModel.from_pretrained(zh_model_path)
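# ignore_mismatched_sizes allows loading even if some weight shapes differ
# from the current model config (mismatched tensors are left uninitialized).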
zh_pretrained_model = DonutModel.from_pretrained(zh_model_path, ignore_mismatched_sizes=True)
'''
if torch.cuda.is_available():
    en_pretrained_model.half()
    device = torch.device("cuda")
    en_pretrained_model.to(device)
'''
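# fp16 on GPU; otherwise run the encoder in bfloat16 on CPU.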
if torch.cuda.is_available():
    zh_pretrained_model.half()
    device = torch.device("cuda")
    zh_pretrained_model.to(device)
else:
    zh_pretrained_model.encoder.to(torch.bfloat16)
#en_pretrained_model.eval()
zh_pretrained_model.eval()
print("have load !")
def demo_process_vqa(input_img, question):
    #global pretrained_model, task_prompt, task_name
    #global zh_pretrained_model, en_pretrained_model, task_prompt, task_name
    global zh_pretrained_model, task_prompt
    input_img = Image.fromarray(input_img)
    user_prompt = task_prompt.replace("{user_input}", question)
    output = zh_pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]
    '''
    if lang == "en":
        output = en_pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]
    else:
        output = zh_pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]
    '''
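    # The model is a question generator: its raw "answer" field holds the
    # generated question, so the fields are swapped before returning.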
    req = {
        "question": output["answer"],
        "answer": output["question"]
    }
    return req
'''
img_path = "imgs/en_img.png"
demo_process_vqa(Image.open(img_path), "605-7227", "en")
img_path = "imgs/zh_img.png"
demo_process_vqa(Image.open(img_path), "银钱", "zh")
'''
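# Each Gradio example supplies one value per input: [image, text].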
example_sample = [["zh_img.png", "银钱"]]
demo = gr.Interface(fn=demo_process_vqa,
                    inputs=["image", "text"],
                    outputs=["json"],
                    examples=example_sample,
                    description="This _example_ was **derived** from [https://github.com/svjack/docvqa-gen](https://github.com/svjack/docvqa-gen)",
                    cache_examples=False
                    )
demo.launch(share=False)