# ui-refexp / app.py
# Author: ivelin
# Last commit: "chore: more examples" (161ad07)
# Page links: raw | history | blame
# Scan status: No virus
# File size: 4.45 kB
import re
import gradio as gr
from PIL import Image, ImageDraw
import math
import torch
import html
from transformers import DonutProcessor, VisionEncoderDecoderModel
# Checkpoint identifier handed to both `from_pretrained` calls below.
pretrained_repo_name = "ivelin/donut-refexp-draft"

# The processor prepares the image tensor and exposes the tokenizer; the
# model produces the decoder sequence from which the bounding box is parsed.
processor = DonutProcessor.from_pretrained(pretrained_repo_name)
model = VisionEncoderDecoderModel.from_pretrained(pretrained_repo_name)

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
def _parse_fraction(value, default):
    """Return ``value`` as a float clamped to ``[0.0, 1.0]``, or ``default``.

    ``default`` is returned when ``value`` is missing (``None``), cannot be
    converted by ``float()``, or is NaN / infinite.
    """
    try:
        fraction = float(value)
    except (TypeError, ValueError):
        return default
    if not math.isfinite(fraction):
        # math.floor() raises on NaN/inf, so treat them as "not predicted".
        return default
    return min(max(fraction, 0.0), 1.0)


def _bbox_to_pixel_box(bbox, width, height):
    """Scale a fractional bounding box dict to ordered integer pixel corners.

    Missing or unusable coordinates default to the full image extent
    (``xmin``/``ymin`` -> 0, ``xmax``/``ymax`` -> 1).

    Returns:
        ``(xmin, ymin, xmax, ymax)`` with ``xmin <= xmax`` and ``ymin <= ymax``.
    """
    xmin = math.floor(width * _parse_fraction(bbox.get("xmin"), 0.0))
    ymin = math.floor(height * _parse_fraction(bbox.get("ymin"), 0.0))
    xmax = math.floor(width * _parse_fraction(bbox.get("xmax"), 1.0))
    ymax = math.floor(height * _parse_fraction(bbox.get("ymax"), 1.0))
    # Recent Pillow releases raise ValueError when the second corner precedes
    # the first, so make sure the corners are ordered before drawing.
    xmin, xmax = sorted((xmin, xmax))
    ymin, ymax = sorted((ymin, ymax))
    return xmin, ymin, xmax, ymax


def process_refexp(image: "Image.Image", prompt: str):
    """Predict and draw the bounding box of the UI element a prompt refers to.

    The prompt is truncated to 80 characters, lowercased and wrapped in the
    RefExp task tokens; the model then generates a sequence that is parsed
    into a bounding box whose coordinates are fractions of the image size.

    Args:
        image: PIL image of the UI screenshot. It is not modified; the
            rectangle is drawn on a copy.
        prompt: Referring expression typed by the user. ``None`` is treated
            as an empty string.

    Returns:
        A ``(annotated_image, bbox)`` tuple: a copy of ``image`` with a green
        rectangle around the predicted region, and the predicted bounding box
        dict (empty when the model output contained no usable box, in which
        case the rectangle covers the whole image).

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    print(f"(image, prompt): {image}, {prompt}")
    if image is None:
        # Fail early with a readable message instead of an obscure error
        # from inside the processor.
        raise ValueError("An input image is required to locate a UI element.")

    # trim prompt to 80 characters and normalize to lowercase
    prompt = (prompt or "")[:80].lower()

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs
    task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_bounding_box>"
    prompt = task_prompt.replace("{user_input}", prompt)
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    print(fr"predicted decoder sequence: {html.escape(sequence)}")
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    # remove first task start token
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    print(
        fr"predicted decoder sequence before token2json: {html.escape(sequence)}")

    # A malformed generation may lack the bounding box entry or carry a
    # non-dict value there; fall back to an empty box instead of crashing.
    parsed = processor.token2json(sequence)
    bbox = parsed.get("target_bounding_box") if isinstance(parsed, dict) else None
    if not isinstance(bbox, dict):
        bbox = {}
    print(f"predicted bounding box: {bbox}")

    print(f"image object: {image}")
    print(f"image size: {image.size}")
    width, height = image.size
    print(f"image width, height: {width, height}")
    print(f"processed prompt: {prompt}")

    # safeguard in case text prediction is missing some bounding box coordinates
    xmin, ymin, xmax, ymax = _bbox_to_pixel_box(bbox, width, height)
    print(
        f"to image pixel values: xmin, ymin, xmax, ymax: {xmin, ymin, xmax, ymax}")
    shape = [(xmin, ymin), (xmax, ymax)]

    # draw the rectangle on a copy so the caller's image stays untouched
    annotated = image.copy()
    ImageDraw.Draw(annotated).rectangle(shape, outline="green", width=5)
    return annotated, bbox
# Text shown on the demo page: heading, introduction and footer links.
title = "Demo: Donut 🍩 for UI RefExp"
description = (
    "Gradio Demo for Donut RefExp task, an instance of "
    "`VisionEncoderDecoderModel` fine-tuned on UIBert RefExp Dataset "
    "(UI Referring Expression). To use it, simply upload your image and "
    "type a question and click 'submit', or click one of the examples to "
    "load them. Read more at the links below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2111.15664' target='_blank'>"
    "Donut: OCR-free Document Understanding Transformer</a>"
    " | "
    "<a href='https://github.com/clovaai/donut' target='_blank'>"
    "Github Repo</a></p>"
)

# One [image file, prompt] pair per example, grouped by screenshot.
examples = [
    [image_file, instruction]
    for image_file, instructions in (
        ("example_1.jpg", (
            "select the setting icon from top right corner",
            "click on down arrow beside the entertainment",
            "select the down arrow button beside lifestyle",
            "click on the image beside the option traffic",
        )),
        ("example_2.jpg", (
            "enter the text field next to the name",
            "click on green color button",
            "click on text which is beside call now",
            "click on more button",
        )),
    )
    for instruction in instructions
]

# Wire process_refexp into the UI: (PIL image, text) in, (PIL image, JSON) out.
demo = gr.Interface(
    fn=process_refexp,
    inputs=[gr.Image(type="pil"), "text"],
    outputs=[gr.Image(type="pil"), "json"],
    title=title,
    description=description,
    article=article,
    examples=examples,
    cache_examples=True,
)
demo.launch()