File size: 4,050 Bytes
7670816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd91985
7670816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import ImageDraw


# Prefer the first GPU when CUDA is visible; fp16 is only used on GPU, CPU
# inference stays in fp32.
_cuda_available = torch.cuda.is_available()
device = "cuda:0" if _cuda_available else "cpu"
torch_dtype = torch.float16 if _cuda_available else torch.float32

# Hub checkpoints offered in the UI's model dropdown. trust_remote_code=True
# allows transformers to run the custom modeling/processing code shipped in the
# checkpoint repository.
_MODEL_IDS = ("AskUI/PTA-1",)

# Models are loaded once at import time (on CPU, default dtype); run_example
# moves the selected one to `device` / `torch_dtype` on use.
models = {
    model_id: AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    for model_id in _MODEL_IDS
}

# Matching processors (tokenizer + image preprocessing), keyed by the same IDs.
processors = {
    model_id: AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    for model_id in _MODEL_IDS
}


def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=3):
    """Draw rectangle outlines onto ``image`` in place and return it.

    Args:
        image: A PIL image. It is modified in place.
        bounding_boxes: Iterable of 4-element ``(x0, y0, x1, y1)`` boxes in
            pixel coordinates. ``None`` entries (e.g. when no box could be
            parsed from the model output) are skipped rather than crashing.
        outline_color: Outline colour passed to ``ImageDraw.rectangle``.
        line_width: Outline width in pixels.

    Returns:
        The same ``image`` object, with the boxes drawn on it.
    """
    draw = ImageDraw.Draw(image)
    for box in bounding_boxes:
        if box is None:
            # florence_output_to_box returns None on failure; nothing to draw.
            continue
        x0, y0, x1, y1 = box
        # Normalise the corners: recent Pillow releases reject rectangles whose
        # second corner lies left of / above the first.
        xmin, xmax = min(x0, x1), max(x0, x1)
        ymin, ymax = min(y0, y1), max(y0, y1)
        draw.rectangle([xmin, ymin, xmax, ymax], outline=outline_color, width=line_width)
    return image


def florence_output_to_box(output):
    """Reduce a parsed detection result to a single ``[xmin, ymin, xmax, ymax]`` box.

    ``output`` is the per-task dict returned by the processor's
    ``post_process_generation``. If it has non-empty ``"polygons"``, the first
    polygon of the first instance is used. Otherwise, if it has non-empty
    ``"bboxes"``, the first box is used.

    The polygon is taken to be a flat ``[x0, y0, x1, y1, ...]`` coordinate
    list. Its axis-aligned bounding box is computed from *all* vertices, so
    the result is correct for any vertex count and ordering.

    Args:
        output: Mapping that may contain ``"polygons"`` and/or ``"bboxes"``.

    Returns:
        A list of four ints ``[xmin, ymin, xmax, ymax]`` (coordinates truncated
        with ``int``), or ``None`` if no usable box could be extracted.
    """
    try:
        if "polygons" in output and len(output["polygons"]) > 0:
            # polygons[0] -> polygons of the first instance; [0] -> its first polygon.
            coords = [int(el) for el in output["polygons"][0][0]]
            xs, ys = coords[0::2], coords[1::2]
            # min()/max() of an empty list raise ValueError, handled below.
            return [min(xs), min(ys), max(xs), max(ys)]
        if "bboxes" in output and len(output["bboxes"]) > 0:
            return [int(el) for el in output["bboxes"][0]]
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # Best-effort parsing: report the malformed output and fall through to None.
        print(f"Error: {e}")
    return None


def run_example(image, text_input, model_id="AskUI/PTA-1"):
    """Locate the element described by ``text_input`` in ``image`` and draw it.

    Runs the selected model with the ``<OPEN_VOCABULARY_DETECTION>`` task
    prompt and reduces the parsed result to one box via
    ``florence_output_to_box``.

    Args:
        image: PIL image from the Gradio input (``type="pil"``). It is ``None``
            when Submit is clicked without an upload.
        text_input: Natural-language description of the target element.
        model_id: Key into the module-level ``models`` / ``processors`` dicts.

    Returns:
        ``(target_box, annotated_image)``. ``target_box`` is
        ``[xmin, ymin, xmax, ymax]`` in pixels, or ``None`` if nothing could be
        parsed. ``annotated_image`` is an RGB copy of the input with the box
        drawn on it (left undrawn when ``target_box`` is ``None``).

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    model = models[model_id].to(device, torch_dtype)
    processor = processors[model_id]
    task_prompt = "<OPEN_VOCABULARY_DETECTION>"
    prompt = task_prompt + (text_input or "")

    # PIL's convert() returns a new image, so the drawing below does not touch
    # the caller's original object.
    image = image.convert("RGB")

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    # Special tokens are kept; presumably post_process_generation needs them
    # (e.g. location tokens) to reconstruct coordinates.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )
    target_box = florence_output_to_box(parsed_answer[task_prompt])
    # Never hand draw_bounding_boxes a None box: it would fail to unpack it.
    boxes = [target_box] if target_box is not None else []
    return target_box, draw_bounding_boxes(image, boxes)


# Custom stylesheet passed to gr.Blocks below.
# NOTE(review): "#output" only matches a component created with
# elem_id="output"; none of the components in this file set that id, so this
# rule appears to have no effect — confirm intent.
css = """
  #output {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""
# UI layout: components render in the order they are created inside the
# `with` contexts, so statement order here defines on-screen order.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
    """
    # PTA-1: Controlling Computers with Small Models
    """)
    gr.Markdown("Check out the model [AskUI/PTA-1](https://huggingface.co/AskUI/PTA-1).")
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            # type="pil" makes Gradio hand run_example a PIL image (or None if empty).
            input_img = gr.Image(label="Input Image", type="pil")
            # Choices come from the keys of the module-level `models` dict.
            model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="AskUI/PTA-1")
            text_input = gr.Textbox(label="User Prompt")
            submit_btn = gr.Button(value="Submit")
        # Right column: outputs, filled in order from run_example's
        # (target_box, annotated_image) return tuple.
        with gr.Column():
            model_output_text = gr.Textbox(label="Model Output Text")
            annotated_image = gr.Image(label="Annotated Image")

    # Example (image path, prompt) pairs. Only image and prompt are listed as
    # inputs; model_selector is not, so run_example's default model_id would
    # apply if fn were invoked from here. With cache_examples=False (and
    # Gradio's default run_on_click), clicking an example presumably just fills
    # the inputs — confirm against the installed Gradio version.
    gr.Examples(
        examples=[
            ["assets/sample.png", "search box"],
            ["assets/sample.png", "Query Service"],
            ["assets/ipad.png", "App Store icon"],
            ["assets/ipad.png", 'colorful icon with letter "S"'],
            ["assets/phone.jpg", "password field"],
            ["assets/phone.jpg", "back arrow icon"],
            ["assets/windows.jpg", "icon with letter S"],
            ["assets/windows.jpg", "Settings"],
        ],
        inputs=[input_img, text_input],
        outputs=[model_output_text, annotated_image],
        fn=run_example,
        cache_examples=False,
        label="Try examples"
    )

    # Submit runs inference with the model currently chosen in the dropdown.
    submit_btn.click(run_example, [input_img, text_input, model_selector], [model_output_text, annotated_image])

# Starts the Gradio server as a module-level side effect (no __main__ guard).
demo.launch(debug=False)