File size: 4,450 Bytes
f0a0a4a
 
671732c
93085f8
f0a0a4a
238e0cb
f0a0a4a
 
 
 
 
 
 
 
 
 
 
fb4e118
186c0c1
a432919
186c0c1
ba6d9e2
 
 
f0a0a4a
a432919
f0a0a4a
 
72d0321
e0dd23e
f0a0a4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ea9225
f0a0a4a
 
 
 
2ea9225
 
e0dd23e
949256f
e0dd23e
f0a0a4a
a432919
 
 
e0dd23e
a43a5b0
e0dd23e
e8e6698
9184635
 
 
 
e0dd23e
 
 
 
 
 
 
a432919
e0dd23e
 
f0a0a4a
 
7d6913a
f0a0a4a
 
e1b2bb3
f37a3ce
 
 
161ad07
 
 
 
f37a3ce
ba6d9e2
a218a91
a432919
 
a218a91
 
 
3d33b8d
 
 
a218a91
3d33b8d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import re
import gradio as gr
from PIL import Image, ImageDraw
import math
import torch
import html
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Hugging Face Hub checkpoint: a Donut model fine-tuned for UI referring expressions.
pretrained_repo_name = "ivelin/donut-refexp-draft"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The processor supplies both image preprocessing (pixel_values) and the
# tokenizer used to build decoder prompts; the model is loaded from the same
# checkpoint and moved to the chosen device once, at import time.
processor = DonutProcessor.from_pretrained(pretrained_repo_name)
model = VisionEncoderDecoderModel.from_pretrained(pretrained_repo_name)
model.to(device)


def _extract_bbox(sequence: str) -> dict:
    """Parse the decoded token sequence into the predicted bounding-box dict.

    Returns the ``target_bounding_box`` entry produced by
    ``processor.token2json`` when it is a dict; otherwise returns an empty
    dict so the caller falls back to default coordinates instead of raising
    ``KeyError`` on an incomplete prediction.
    """
    parsed = processor.token2json(sequence)
    bbox = parsed.get("target_bounding_box") if isinstance(parsed, dict) else None
    return bbox if isinstance(bbox, dict) else {}


def _relative_coord(bbox: dict, key: str, default: float) -> float:
    """Return ``bbox[key]`` as a float clamped to [0, 1].

    Falls back to *default* when the key is missing, the value cannot be
    parsed as a float, or it is not finite (e.g. NaN), since model output is
    free text and may be malformed.
    """
    try:
        value = float(bbox.get(key, default))
    except (TypeError, ValueError):
        return default
    if not math.isfinite(value):
        return default
    return min(max(value, 0.0), 1.0)


def _to_pixel_box(bbox: dict, width: int, height: int) -> list:
    """Convert a relative bbox dict into ``[(xmin, ymin), (xmax, ymax)]`` pixels.

    Missing coordinates default to the full image (0 for mins, 1 for maxes).
    Corners are re-ordered so the first point is the top-left one, because
    the model may emit them swapped and ``ImageDraw.rectangle`` expects
    x1 >= x0 and y1 >= y0.
    """
    xmin = math.floor(width * _relative_coord(bbox, "xmin", 0.0))
    ymin = math.floor(height * _relative_coord(bbox, "ymin", 0.0))
    xmax = math.floor(width * _relative_coord(bbox, "xmax", 1.0))
    ymax = math.floor(height * _relative_coord(bbox, "ymax", 1.0))
    xmin, xmax = sorted((xmin, xmax))
    ymin, ymax = sorted((ymin, ymax))
    return [(xmin, ymin), (xmax, ymax)]


def process_refexp(image: Image.Image, prompt: str):
    """Locate the UI element described by *prompt* in *image*.

    Args:
        image: Screenshot to search. It is not modified; a converted RGB
            copy is annotated and returned instead.
        prompt: Referring expression, e.g. "click on more button". Only the
            first 80 characters are used, lower-cased.

    Returns:
        A tuple ``(annotated_image, bbox)``: an RGB copy of *image* with the
        predicted box drawn in green, and the predicted bounding-box dict
        (relative coordinates as emitted by the model; empty if the model
        produced no box).
    """
    print(f"(image, prompt): {image}, {prompt}")

    # Donut's image preprocessing expects 3-channel input. Image.convert also
    # returns a copy, so drawing below never mutates the caller's image.
    image = image.convert("RGB")

    # trim prompt to 80 characters and normalize to lowercase
    prompt = (prompt or "")[:80].lower()

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs: the task prompt primes the decoder so that it
    # continues by generating the bounding-box fields
    task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_bounding_box>"
    prompt = task_prompt.replace("{user_input}", prompt)
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    print(f"predicted decoder sequence: {html.escape(sequence)}")
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    # remove first task start token
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    print(
        f"predicted decoder sequence before token2json: {html.escape(sequence)}")
    bbox = _extract_bbox(sequence)
    print(f"predicted bounding box: {bbox}")

    print(f"image object: {image}")
    print(f"image size: {image.size}")
    width, height = image.size
    print(f"image width, height: {width, height}")
    print(f"processed prompt: {prompt}")

    # safeguard in case text prediction is missing or has malformed
    # bounding box coordinates
    shape = _to_pixel_box(bbox, width, height)
    (xmin, ymin), (xmax, ymax) = shape

    print(
        f"to image pixel values: xmin, ymin, xmax, ymax: {xmin, ymin, xmax, ymax}")

    # draw the predicted rectangle on the (copied) image
    draw = ImageDraw.Draw(image)
    draw.rectangle(shape, outline="green", width=5)
    return image, bbox


# User-facing text and sample inputs shown by the Gradio interface.
title = "Demo: Donut 🍩 for UI RefExp"
# The model locates a UI element from a referring expression (a command such
# as "click on more button"), not from a question, so the instructions say so.
description = "Gradio Demo for Donut RefExp task, an instance of `VisionEncoderDecoderModel` fine-tuned on the UIBert RefExp dataset (UI Referring Expression). To use it, upload a UI screenshot, type a referring expression describing the UI element to locate (for example, 'click on more button'), and click 'submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
# Each example is [image path, referring expression], matching the interface inputs.
examples = [["example_1.jpg", "select the setting icon from top right corner"],
            ["example_1.jpg", "click on down arrow beside the entertainment"],
            ["example_1.jpg", "select the down arrow button beside lifestyle"],
            ["example_1.jpg", "click on the image beside the option traffic"],
            ["example_2.jpg", "enter the text field next to the name"],
            ["example_2.jpg", "click on green color button"],
            ["example_2.jpg", "click on text which is beside call now"],
            ["example_2.jpg", "click on more button"],
            ]

# Wire the inference function into a Gradio UI: a PIL image plus free text in,
# the annotated image and the predicted box (rendered as JSON) out.
# cache_examples=True asks Gradio to cache outputs for the examples above.
demo = gr.Interface(
    process_refexp,
    [gr.Image(type="pil"), "text"],
    [gr.Image(type="pil"), "json"],
    title=title,
    description=description,
    article=article,
    examples=examples,
    cache_examples=True,
)

# Start the web server (runs whenever this module is executed or imported).
demo.launch()