# donut-mrz / app.py
import argparse
import os

import gradio as gr
import torch
from donut import DonutModel
from PIL import Image


def demo_process_vqa(input_img, question):
    # DocVQA mode: embed the user question into the task prompt and run inference.
    global pretrained_model, task_prompt, task_name
    input_img = Image.fromarray(input_img)
    user_prompt = task_prompt.replace("{user_input}", question)
    return pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]


def demo_process(input_img):
    # MRZ / classification mode: run inference with the fixed task prompt and
    # return the text sequence up to the closing MachineReadableZone tag.
    global pretrained_model, task_prompt, task_name
    input_img = Image.fromarray(input_img)
    best_output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
    return best_output["text_sequence"].split(" </s_MachineReadableZone>")[0]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="s_passport")
    parser.add_argument("--pretrained_path", type=str, default=os.getcwd())
    parser.add_argument("--port", type=int, default=12345)
    parser.add_argument("--url", type=str, default="0.0.0.0")
    parser.add_argument("--sample_img_path", type=str)
    args, left_argv = parser.parse_known_args()
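
    # Example invocation (paths are illustrative, not part of the original script):
    #   python app.py --task s_passport --pretrained_path ./checkpoint --sample_img_path ./sample.png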
    task_name = args.task
    if task_name == "docvqa":
        task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
    else:  # rvlcdip, cord, ...
        task_prompt = f"<s_{task_name}>"
    # Collect the bundled example images plus an optional user-supplied sample.
    example_sample = [os.path.join("images", image) for image in os.listdir("images")]
    if args.sample_img_path:
        example_sample.append(args.sample_img_path)
    pretrained_model = DonutModel.from_pretrained(args.pretrained_path)
    if torch.cuda.is_available():
        # Use half precision on GPU for faster inference.
        pretrained_model.half()
        device = torch.device("cuda")
        pretrained_model.to(device)
    pretrained_model.eval()

    gr.Interface(
        fn=demo_process_vqa if task_name == "docvqa" else demo_process,
        inputs=["image", "text"] if task_name == "docvqa" else "image",
        outputs="text",
        title="Demo of MRZ Extraction model based on 🍩 architecture",
        examples=example_sample if example_sample else None,
    ).launch()
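    # Note: --url and --port are parsed above but not forwarded to Gradio; to honor them,
    # launch() could be called as .launch(server_name=args.url, server_port=args.port).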