File size: 5,285 Bytes
a5153ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbbc9f1
 
a5153ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import random
import re

import gradio as gr
import requests
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModelForVision2Seq

# Hugging Face checkpoint used for both the processor and the model.
MODEL_ID = "microsoft/kosmos-2-patch14-224"

# Loaded once at import time and shared by every request.
# The processor handles both the text prompt and the image (see process_image).
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForVision2Seq.from_pretrained(MODEL_ID)


def draw_bounding_boxes(image: Image.Image, entities):
    """Draw a labelled rectangle for every box of every entity onto *image*.

    Args:
        image: PIL image to annotate. It is modified in place; ``process_image``
            passes a copy so the original upload is untouched.
        entities: Iterable of ``(label, span, boxes)`` triples, as returned by
            ``processor.post_process_generation``. Each box is
            ``(x1, y1, x2, y2)``; the coordinates are multiplied by the image
            width/height, so they are expected to be normalised fractions.

    Returns:
        The same ``image`` object, with boxes and labels drawn on it.
    """
    draw = ImageDraw.Draw(image)
    width, height = image.size

    color_bank = ["#0AC2FF", "#30D5C8", "#F3C300", "#47FF0A", "#C2FF0A"]

    font_size = 20
    try:
        font = ImageFont.truetype("assets/arial.ttf", font_size)
    except OSError:
        # Bundled font missing or unreadable: fall back to Pillow's default font.
        font = ImageFont.load_default()

    for index, (label, _span, boxes) in enumerate(entities):
        # One colour per entity, shared by its outline(s) and label, so each
        # label is visually tied to its box; cycling keeps neighbours distinct.
        color = color_bank[index % len(color_bank)]
        for box in boxes:
            x1, y1 = box[0] * width, box[1] * height
            x2, y2 = box[2] * width, box[3] * height

            draw.rectangle((x1, y1, x2, y2), outline=color, width=4)

            # Prefer placing the label just above the box; if that would fall
            # off the top of the image, put it just inside the box instead.
            text_y = y1 - font_size - 5
            if text_y < 0:
                text_y = y1 + 5
            draw.text((x1 + 5, text_y), label, fill=color, font=font)

    return image

def highlight_entities(text, entities):
    """Wrap every occurrence of each entity label in *text* with asterisks.

    Args:
        text: Caption text (``process_image`` passes the cleaned caption from
            ``processor.post_process_generation``).
        entities: Sequence of entity tuples whose first element is the label.

    Returns:
        ``text`` with every occurrence of each label enclosed in ``*...*``.
        Each occurrence is wrapped exactly once, even if a label appears
        several times in ``entities`` or one label contains another. Empty
        labels are ignored.
    """
    # De-duplicate and drop empty labels. Chained str.replace would wrap
    # repeated labels twice, and replacing "" would insert markers between
    # every character.
    labels = {entity[0] for entity in entities if entity[0]}
    if not labels:
        return text

    # A single left-to-right regex pass never re-visits text it already
    # wrapped. Longer labels come first in the alternation, so "the fire"
    # wins over "fire" at the same position.
    ordered = sorted(labels, key=lambda label: (-len(label), label))
    pattern = re.compile("|".join(re.escape(label) for label in ordered))
    return pattern.sub(lambda match: f"*{match.group(0)}*", text)

def process_image(image, prompt_option, custom_prompt):
    """Run Kosmos-2 grounded generation on an image and visualise the result.

    Args:
        image: A PIL image, or anything ``Image.open`` accepts (e.g. a path).
        prompt_option: ``"Brief"`` or ``"Detailed"`` select a preset prompt;
            any other value uses ``custom_prompt``.
        custom_prompt: Prompt text used when ``prompt_option`` is not a preset.

    Returns:
        Tuple ``(processed_image, processed_text, entities, highlighted_text)``:
        - a copy of the image with the entity boxes drawn on it;
        - the cleaned caption and the entity list from
          ``processor.post_process_generation``;
        - the caption with entity labels wrapped in asterisks.

    Raises:
        gr.Error: If no image was supplied, or the custom prompt is empty.
    """
    if image is None:
        # Clicking "Run Model" before uploading passes None; without this
        # guard Image.open(None) fails with an opaque traceback.
        raise gr.Error("Please upload an image first.")
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    # Normalise uploads in other modes (RGBA, P, L, ...) to a 3-channel image,
    # so the processor and the hex-colour drawing both see the same image.
    image = image.convert("RGB")

    if prompt_option == "Brief":
        prompt = "<grounding>An image of"
    elif prompt_option == "Detailed":
        prompt = "<grounding> Describe this image in detail:"
    else:  # Custom
        prompt = custom_prompt
        if not isinstance(prompt, str) or not prompt.strip():
            raise gr.Error("Please enter a custom prompt.")

    inputs = processor(text=prompt, images=image, return_tensors="pt")
    generated_ids = model.generate(
        pixel_values=inputs["pixel_values"],
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        image_embeds=None,
        image_embeds_position_mask=inputs["image_embeds_position_mask"],
        use_cache=True,
        max_new_tokens=128,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    processed_text, entities = processor.post_process_generation(generated_text)

    # Draw on a copy so the image used for inference is left unmodified.
    processed_image = draw_bounding_boxes(image.copy(), entities)

    # Named highlighted_text to avoid shadowing the module-level
    # `highlighted_entities` Textbox component.
    highlighted_text = highlight_entities(processed_text, entities)

    return processed_image, processed_text, entities, highlighted_text

def clear_interface():
    """Return blank values for the components reset by the Clear button.

    The Clear button's click handler resets four components: image input,
    processed image, processed text and entities JSON. One ``None`` is
    returned for each.
    """
    return (None,) * 4


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Kosmos-2 VQA Demo")
    gr.Markdown(
        "Run this space on your own hardware with this command: "
        "```docker run -it -p 7860:7860 --platform=linux/amd64 \t"
        "registry.hf.space/macadeliccc-kosmos-2-demo:latest python app.py```"
    )

    # Input image next to the annotated result.
    with gr.Row(equal_height=True):
        image_input = gr.Image(type="pil", label="Upload Image")
        processed_image_output = gr.Image(label="Processed Image")

    with gr.Row(equal_height=True):
        with gr.Column():
            with gr.Accordion("Prompt Options"):
                prompt_option = gr.Radio(
                    choices=["Brief", "Detailed", "Custom"],
                    label="Select Prompt Option",
                    value="Brief",
                )
                # Hidden until "Custom" is selected.
                custom_prompt_input = gr.Textbox(label="Custom Prompt", visible=False)

                def show_custom_prompt_input(prompt_option):
                    # A bare return value would become the Textbox's *value*
                    # (the text "True"/"False"). gr.update changes its
                    # visibility instead and leaves the typed prompt intact.
                    return gr.update(visible=prompt_option == "Custom")

                prompt_option.change(
                    show_custom_prompt_input,
                    inputs=[prompt_option],
                    outputs=[custom_prompt_input],
                )

    with gr.Row(equal_height=True):
        submit_button = gr.Button("Run Model")
        clear_button = gr.Button("Clear", elem_id="clear_button")

    with gr.Row(equal_height=True):
        with gr.Column():
            highlighted_entities = gr.Textbox(label="Processed Text")
        with gr.Column():
            with gr.Accordion("Entities"):
                entities_output = gr.JSON(label="Entities", elem_id="entities_output")

    # Each example row fills (image, prompt option, custom prompt).
    examples = [
        ["assets/snowman.jpg", "Custom", "<grounding> Question: Where is<phrase> the fire</phrase><object><patch_index_0005><patch_index_0911></object> next to? Answer:"],
        ["assets/traffic.jpg", "Detailed", "<grounding> Describe this image in detail:"],
        ["assets/umbrellas.jpg", "Brief", "<grounding>An image of"],
    ]
    gr.Examples(examples, inputs=[image_input, prompt_option, custom_prompt_input])

    with gr.Row(equal_height=True):
        with gr.Accordion("Additional Info"):
            gr.Markdown(
                "This demo uses the "
                "[Kosmos-2](https://huggingface.co/microsoft/kosmos-2-patch14-224) model."
            )

    # NOTE(review): process_image returns four values (image, processed text,
    # entities, highlighted text), but only three outputs are wired here, so
    # the fourth value is dropped. Confirm which caption the "Processed Text"
    # box should show.
    submit_button.click(
        fn=process_image,
        inputs=[image_input, prompt_option, custom_prompt_input],
        outputs=[processed_image_output, highlighted_entities, entities_output],
    )

    clear_button.click(
        fn=clear_interface,
        inputs=[],
        outputs=[image_input, processed_image_output, highlighted_entities, entities_output],
    )


demo.launch()