ui-refexp / app.py
ivelin
fix: switch model checkpoint
a2780d6
raw
history blame
6.08 kB
import re
import gradio as gr
from PIL import Image, ImageDraw
import math
import torch
import html
from transformers import DonutProcessor, VisionEncoderDecoderModel
# Pick the inference device up front: the GPU when PyTorch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face Hub repository that provides both the Donut processor
# (image preprocessing + tokenizer) and the fine-tuned RefExp model weights.
pretrained_repo_name = "ivelin/donut-refexp-draft-precision2decs"
print(f"Loading model checkpoint: {pretrained_repo_name}")
processor = DonutProcessor.from_pretrained(pretrained_repo_name)
model = VisionEncoderDecoderModel.from_pretrained(pretrained_repo_name)
# Move the model's parameters to the chosen device; inputs are moved to the
# same device at generation time.
model.to(device)
def _parse_coordinate(value, default: float) -> float:
    """Convert one predicted bounding-box coordinate to a finite float.

    The decoder emits coordinates as text, so a value may be missing, not a
    number, a nested structure rather than a scalar, or a non-finite literal
    such as ``"nan"``/``"inf"``. Any of those falls back to ``default`` so the
    caller can still draw a box instead of crashing.
    """
    if value is None:
        return default
    try:
        coordinate = float(value)
    except (TypeError, ValueError):
        return default
    # float() happily accepts "nan"/"inf"; those would make math.floor() fail.
    return coordinate if math.isfinite(coordinate) else default


def _extract_bbox(seqjson):
    """Return the predicted ``target_bounding_box`` mapping, or None if absent.

    ``processor.token2json`` is not guaranteed to return a single dict (it can
    return a list), so the first dict element of a list is used. The bounding
    box must itself be a mapping of coordinate names to values; anything else
    (e.g. a bare text value) is treated as "no prediction".
    """
    if isinstance(seqjson, list):
        seqjson = next((item for item in seqjson if isinstance(item, dict)), {})
    if not isinstance(seqjson, dict):
        return None
    bbox = seqjson.get("target_bounding_box")
    return bbox if isinstance(bbox, dict) else None


def process_refexp(image: Image.Image, prompt: str):
    """Locate the UI element described by ``prompt`` in the screenshot ``image``.

    The prompt is lowercased and truncated to 80 characters, wrapped in the
    RefExp task prompt, and decoded greedily by the Donut model. The predicted
    ``target_bounding_box`` is parsed and drawn on a copy of the image.

    Args:
        image: Screenshot from the Gradio ``gr.Image(type="pil")`` input.
        prompt: Referring expression typed by the user.

    Returns:
        ``(annotated_image, bbox)``, matching the interface outputs
        ``[gr.Image, "json"]``. ``bbox`` holds ``xmin``/``ymin``/``xmax``/
        ``ymax`` as fractions of the image width/height, plus the raw
        ``"decoder output sequence"``. When the model predicts no usable
        bounding box, all coordinates are 0 and the image is returned as-is.
    """
    print(f"(image, prompt): {image}, {prompt}")
    # trim prompt to 80 characters and normalize to lowercase
    prompt = (prompt or "")[:80].lower()

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs: task start token + user prompt, ending with the
    # opening bounding-box tag so the model continues with the coordinates
    task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_bounding_box>"
    prompt = task_prompt.replace("{user_input}", prompt)
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer (greedy decoding, the unknown token is banned)
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    print(fr"predicted decoder sequence: {html.escape(sequence)}")
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    # remove first task start token
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    print(
        fr"predicted decoder sequence before token2json: {html.escape(sequence)}")
    seqjson = processor.token2json(sequence)

    # safeguard in case predicted sequence does not include a usable
    # target_bounding_box token
    bbox = _extract_bbox(seqjson)
    if bbox is None:
        print(
            f"token2bbox seq has no predicted target_bounding_box, seq:{sequence}")
        bbox = {"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0,
                "decoder output sequence": sequence}
        # the interface declares two outputs, so always return both
        return image, bbox
    print(f"predicted bounding box with text coordinates: {bbox}")

    # safeguard in case text prediction is missing some bounding box coordinates
    # or coordinates are not valid numeric values: an unusable min falls back to
    # 0 and an unusable max to 1 (the far edge of the image)
    xmin = _parse_coordinate(bbox.get("xmin"), 0.0)
    ymin = _parse_coordinate(bbox.get("ymin"), 0.0)
    xmax = _parse_coordinate(bbox.get("xmax"), 1.0)
    ymax = _parse_coordinate(bbox.get("ymax"), 1.0)

    # replace str with float coords
    bbox = {"xmin": xmin, "ymin": ymin, "xmax": xmax,
            "ymax": ymax, "decoder output sequence": sequence}
    print(f"predicted bounding box with float coordinates: {bbox}")
    print(f"image object: {image}")
    print(f"image size: {image.size}")
    width, height = image.size
    print(f"image width, height: {width, height}")
    print(f"processed prompt: {prompt}")

    # scale the relative coordinates to integer pixel positions
    xmin = math.floor(width*bbox["xmin"])
    ymin = math.floor(height*bbox["ymin"])
    xmax = math.floor(width*bbox["xmax"])
    ymax = math.floor(height*bbox["ymax"])
    print(
        f"to image pixel values: xmin, ymin, xmax, ymax: {xmin, ymin, xmax, ymax}")

    # the model may predict min/max swapped; recent Pillow versions reject a
    # rectangle whose second corner precedes the first, so order the corners
    shape = [(min(xmin, xmax), min(ymin, ymax)),
             (max(xmin, xmax), max(ymin, ymax))]

    # draw bbox rectangle on a copy so the caller's image is not modified:
    # a 5px green outline overlaid with a 2px white one for contrast
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    draw.rectangle(shape, outline="green", width=5)
    draw.rectangle(shape, outline="white", width=2)
    return annotated, bbox
# Text shown above and below the Gradio demo.
title = "Demo: Donut 🍩 for UI RefExp (by GuardianUI)"
description = (
    "Gradio Demo for Donut RefExp task, an instance of `VisionEncoderDecoderModel` "
    "fine-tuned on UIBert RefExp Dataset (UI Referring Expression). "
    "To use it, simply upload your image and type a question and click 'submit', "
    "or click one of the examples to load them. "
    "Read more at the links below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2111.15664' target='_blank'>"
    "Donut: OCR-free Document Understanding Transformer</a>"
    " | "
    "<a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a>"
    "</p>"
)

# Clickable examples as [image file, referring expression] pairs, grouped by
# screenshot and flattened in order.
examples = [
    [image_file, example_prompt]
    for image_file, prompts in (
        ("example_1.jpg", (
            "select the setting icon from top right corner",
            "click on down arrow beside the entertainment",
            "select the down arrow button beside lifestyle",
            "click on the image beside the option traffic",
        )),
        ("example_2.jpg", (
            "enter the text field next to the name",
            "click on green color button",
            "click on text which is beside call now",
            "click on more button",
        )),
        ("example_3.jpg", (
            "select the third row first image",
            "click the tick mark on the first image",
            "select the ninth image",
            "select the add icon",
            "click the first image",
            "select the first column second image",
            "select the bottom right image",
            "select the second row second image",
        )),
    )
    for example_prompt in prompts
]
# Wire the inference function into a Gradio interface: a PIL image and a text
# prompt in; the annotated image and the bounding-box JSON out (the two values
# process_refexp returns). cache_examples=True asks Gradio to cache the
# example outputs so they load quickly when clicked.
demo = gr.Interface(fn=process_refexp,
                    inputs=[gr.Image(type="pil"), "text"],
                    outputs=[gr.Image(type="pil"), "json"],
                    title=title,
                    description=description,
                    article=article,
                    examples=examples,
                    cache_examples=True,
                    )

# Start the web server only when this file is run as a script
# (e.g. `python app.py`), so importing the module does not launch it.
if __name__ == "__main__":
    demo.launch()