# ui-refexp / app.py
# ivelin's picture
# add more examples
# 562368d
# raw
# history blame
# 7.01 kB
import re
import gradio as gr
from PIL import Image, ImageDraw
import math
import torch
import html
from transformers import DonutProcessor, VisionEncoderDecoderModel
# Checkpoint to load: repository name and the revision (branch or commit) to use.
pretrained_repo_name = 'ivelin/donut-refexp-combined-v1'
# Revisions noted by the author together with their measured IoU:
#   '348ddad8e958d370b7e341acd6050330faa0500f'  -> IoU = 0.47
#   '41210d7c42a22e77711711ec45508a6b63ec380f'  -> IoU = 0.42
# Use 'main' for the latest revision.
pretrained_revision = 'main'

print(f"Loading model checkpoint: {pretrained_repo_name}")

# The processor and the model are loaded from the same repo/revision pair.
processor = DonutProcessor.from_pretrained(
    pretrained_repo_name,
    revision=pretrained_revision,
)
model = VisionEncoderDecoderModel.from_pretrained(
    pretrained_repo_name,
    revision=pretrained_revision,
)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
def _coord_to_float(value, default):
    """Return *value* as a finite float, or *default* when that is impossible."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        # TypeError: None or a nested container; ValueError: non-numeric text.
        return default
    # NaN / +-inf cannot be passed to math.floor() later on.
    return number if math.isfinite(number) else default


def process_refexp(image: Image.Image, prompt: str):
    """Predict and draw the bounding box of the UI element described by *prompt*.

    The prompt is truncated to 80 characters, lower-cased and wrapped in the
    RefExp task prompt; the module-level ``processor`` and ``model`` are then
    used to generate a bounding box, whose coordinates are interpreted as
    fractions of the image width and height.

    Args:
        image: PIL image of the UI screenshot. The rectangle is drawn directly
            on this object (it is modified in place).
        prompt: Natural-language referring expression.

    Returns:
        A ``(image, bbox)`` tuple. ``bbox`` is a dict with float ``xmin``,
        ``ymin``, ``xmax``, ``ymax`` plus the raw ``"decoder output sequence"``.
        When the model output contains no usable ``target_bounding_box``, the
        image is returned unchanged together with an all-zero bbox.
    """
    print(f"(image, prompt): {image}, {prompt}")
    # trim prompt to 80 characters and normalize to lowercase
    prompt = prompt[:80].lower()

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs
    task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_bounding_box>"
    prompt = task_prompt.replace("{user_input}", prompt)
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    print(fr"predicted decoder sequence: {html.escape(sequence)}")
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    # remove first task start token
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    print(
        fr"predicted decoder sequence before token2json: {html.escape(sequence)}")
    seqjson = processor.token2json(sequence)

    # safeguard in case predicted sequence does not include a usable
    # target_bounding_box (missing key, or a value that is not a mapping)
    bbox = seqjson.get('target_bounding_box') if isinstance(seqjson, dict) else None
    if not isinstance(bbox, dict):
        print(
            f"token2bbox seq has no predicted target_bounding_box, seq:{sequence}")
        bbox = {"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0}
        # Same (image, json) shape as the normal path so both Gradio outputs
        # are always populated.
        return image, bbox

    print(f"predicted bounding box with text coordinates: {bbox}")
    # safeguard in case text prediction is missing some bounding box coordinates
    # or coordinates are not valid numeric values
    xmin = _coord_to_float(bbox.get("xmin", 0), 0.0)
    ymin = _coord_to_float(bbox.get("ymin", 0), 0.0)
    xmax = _coord_to_float(bbox.get("xmax", 1), 1.0)
    ymax = _coord_to_float(bbox.get("ymax", 1), 1.0)
    # replace str with float coords
    bbox = {"xmin": xmin, "ymin": ymin, "xmax": xmax,
            "ymax": ymax, "decoder output sequence": sequence}
    print(f"predicted bounding box with float coordinates: {bbox}")
    print(f"image object: {image}")
    print(f"image size: {image.size}")
    width, height = image.size
    print(f"image width, height: {width, height}")
    print(f"processed prompt: {prompt}")

    # scale the fractional coordinates to pixel positions
    xmin = math.floor(width * bbox["xmin"])
    ymin = math.floor(height * bbox["ymin"])
    xmax = math.floor(width * bbox["xmax"])
    ymax = math.floor(height * bbox["ymax"])
    print(
        f"to image pixel values: xmin, ymin, xmax, ymax: {xmin, ymin, xmax, ymax}")

    # ImageDraw.rectangle expects x1 >= x0 and y1 >= y0; order the corners so
    # an inverted prediction is still drawn instead of raising.
    left, right = sorted((xmin, xmax))
    top, bottom = sorted((ymin, ymax))
    shape = [(left, top), (right, bottom)]

    # draw bbox rectangle: thick green outline with a thin white one on top
    img1 = ImageDraw.Draw(image)
    img1.rectangle(shape, outline="green", width=5)
    img1.rectangle(shape, outline="white", width=2)
    return image, bbox
# Static texts shown by the Gradio page: heading, intro (markdown/HTML) and footer.
title = "Demo: Donut 🍩 for UI RefExp (by GuardianUI)"

# Long literals are split with implicit string concatenation; the resulting
# values are unchanged.
description = (
    "Gradio Demo for Donut RefExp task, an instance of "
    "`VisionEncoderDecoderModel` fine-tuned on "
    "[UIBert RefExp](https://huggingface.co/datasets/ivelin/ui_refexp_saved) "
    "Dataset (UI Referring Expression). "
    "To use it, simply upload your image and type a prompt and click 'submit', "
    "or click one of the examples to load them. "
    "See the model training "
    "<a href='https://colab.research.google.com/github/ivelin/donut_ui_refexp/blob/main/Fine_tune_Donut_on_UI_RefExp.ipynb' target='_parent'>Colab Notebook</a> "
    "for this space. Read more at the links below."
)

article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2111.15664' target='_blank'>"
    "Donut: OCR-free Document Understanding Transformer</a>"
    " | "
    "<a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a>"
    "</p>"
)
# Example inputs offered in the UI: each entry is [image file, prompt].
# NOTE: every entry must end with a comma; two adjacent list literals without
# one are parsed as a subscription (`[...][...]`) and fail when the module loads.
examples = [
    ["example_1.jpg", "select the setting icon from top right corner"],
    ["example_1.jpg", "click on down arrow beside the entertainment"],
    ["example_1.jpg", "select the down arrow button beside lifestyle"],
    ["example_1.jpg", "click on the image beside the option traffic"],
    ["example_2.jpg", "enter the text field next to the name"],
    ["example_3.jpg", "select the third row first image"],
    ["example_3.jpg", "click the tick mark on the first image"],
    ["example_3.jpg", "select the ninth image"],
    ["example_3.jpg", "select the add icon"],
    ["example_3.jpg", "click the first image"],
    ["val-image-4.jpg", "select 4153365454"],
    ["val-image-4.jpg", "go to cell"],
    ["val-image-4.jpg", "select number above cell"],
    ["val-image-1.jpg", "select calendar option"],
    ["val-image-1.jpg", "select photos&videos option"],
    ["val-image-2.jpg", "click on change store"],
    ["example_2.jpg", "click on green color button"],
    ["example_2.jpg", "click on text which is beside call now"],
    ["example_2.jpg", "click on more button"],
    ["val-image-2.jpg", "click on shop menu at the bottom"],
    ["val-image-3.jpg", "click on image above short meow"],
    ["val-image-3.jpg", "go to cat sounds"],
]
# Wire process_refexp into a Gradio interface: (PIL image, text) in,
# (PIL image, JSON) out.
demo_inputs = [gr.Image(type="pil"), "text"]
demo_outputs = [gr.Image(type="pil"), "json"]

demo = gr.Interface(
    fn=process_refexp,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    # caching examples inference takes too long to start space after app change commit
    cache_examples=False,
)

# Start the web app.
demo.launch()