|
import sys |
|
import os |
|
import pandas as pd |
|
import numpy as np |
|
import shutil |
|
|
|
from tqdm import tqdm |
|
import re |
|
|
|
from donut import DonutModel |
|
import torch |
|
from PIL import Image |
|
import gradio as gr |
|
|
|
|
|
|
|
# Path (or Hugging Face hub id) of the Donut checkpoint fine-tuned to
# generate a question from Chinese text found on a document image.
zh_model_path = "question_generator_by_zh_on_pic"

# DocVQA-style prompt template; "{user_input}" is substituted with the
# user-supplied text before inference.  The model completes the
# <s_answer> slot, so generated text comes back under the "answer" key.
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"

# ignore_mismatched_sizes allows loading even if some weight shapes differ
# from the base config (e.g. resized embeddings in the fine-tuned checkpoint).
zh_pretrained_model = DonutModel.from_pretrained(zh_model_path, ignore_mismatched_sizes=True)
|
# Device / precision setup.
# On CUDA: run the whole model in fp16 on the GPU.
# On CPU: keep fp32 weights except the encoder, which is cast to bfloat16
# to reduce memory while staying numerically safe for inference.
if torch.cuda.is_available():
    zh_pretrained_model.half()
    device = torch.device("cuda")
    zh_pretrained_model.to(device)
else:
    # torch is already imported at the top of the file; the redundant
    # in-branch re-import that used to live here has been removed.
    zh_pretrained_model.encoder.to(torch.bfloat16)

# Inference only: disable dropout / put norm layers in eval mode.
zh_pretrained_model.eval()
print("have load !")
|
|
|
def demo_process_vqa(input_img, question):
    """Run Donut question-generation inference on a document image.

    Parameters
    ----------
    input_img : numpy.ndarray
        Image delivered by the gradio ``image`` input (H x W x C array).
    question : str
        Text inserted into the prompt's ``<s_question>`` slot.  For this
        question-generator checkpoint the user supplies answer text that
        appears on the image.

    Returns
    -------
    dict
        ``{"question": ..., "answer": ...}``.  NOTE(review): the keys are
        deliberately swapped relative to the raw model output — the
        model's "answer" slot holds the *generated question* — confirm
        against the upstream repo before changing this mapping.
    """
    image = Image.fromarray(input_img)
    # Fill the user text into the prompt template (module-level constant;
    # no ``global`` needed since we only read it).
    prompt = task_prompt.replace("{user_input}", question)
    prediction = zh_pretrained_model.inference(image, prompt=prompt)["predictions"][0]
    return {
        "question": prediction["answer"],
        "answer": prediction["question"],
    }
|
|
|
# One clickable example shown in the UI: (image path, user-supplied text).
example_sample = [["zh_img.png", "้ถ้ฑ้"]]

# Build the gradio demo: image + text in, JSON prediction out.
# (A stale commented-out demo snippet that called demo_process_vqa with a
# removed third "lang" argument was deleted here.)
demo = gr.Interface(
    fn=demo_process_vqa,
    inputs=["image", "text"],
    outputs=["json"],
    # example_sample is a non-empty literal, so the old
    # ``example_sample if example_sample else None`` guard was redundant.
    examples=example_sample,
    # NOTE(review): fixed typo in the user-facing text: "drive" -> "derived".
    description="This _example_ was **derived** from <br/><b><h4>[https://github.com/svjack/docvqa-gen](https://github.com/svjack/docvqa-gen)</h4></b>\n",
    cache_examples=False,
)
demo.launch(share=False)