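"""Gradio demo for multimodal understanding with Qwen3-VL-4B-Instruct.

Supports four task categories (Query, Caption, Point, Detect). For Point and
Detect, the model's JSON output is parsed and rendered onto the input image
with the supervision library.
"""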
import gradio as gr
import torch
import numpy as np
import supervision as sv
from typing import Iterable
from transformers import (
    Qwen3VLForConditionalGeneration,
    Qwen3VLProcessor,
)
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
import json
import ast
import re
from PIL import Image
from spaces import GPU

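# Register a custom steel-blue palette with gradio's color registry so the
# SteelBlueTheme below can use it as its secondary hue.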
colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)

class SteelBlueTheme(Soft):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )

steel_blue_theme = SteelBlueTheme()

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = "auto"

CATEGORIES = ["Query", "Caption", "Point", "Detect"]

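# Load the model and processor once at startup. dtype="auto" lets transformers
# pick the checkpoint's native precision (bfloat16 on GPU), and device_map
# places the weights on DEVICE.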
qwen_model = Qwen3VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
    dtype=DTYPE,
    device_map=DEVICE,
).eval()
qwen_processor = Qwen3VLProcessor.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
)

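# Point/Detect responses are expected to be JSON but often arrive wrapped in a
# markdown code fence. Strip the fence, try strict JSON first, then fall back
# to ast.literal_eval for Python-style literals; return {} if both fail.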
def safe_parse_json(text: str):
    text = text.strip()
    text = re.sub(r"^```(json)?", "", text)
    text = re.sub(r"```$", "", text)
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(text)
    except Exception:
        return {}


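# Draw parsed results onto the image. Coordinates in `result` are assumed to be
# normalized to [0, 1] (see process_qwen) and are scaled back to pixel space
# before being passed to supervision's KeyPoints / Detections annotators.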
def annotate_image(image: Image.Image, result: dict):
    if not isinstance(image, Image.Image) or not isinstance(result, dict):
        return image

    # Work on an RGB copy (convert() returns a new image) so the annotators get a 3-channel array
    image = image.convert("RGB")
    original_width, original_height = image.size

    if "points" in result and result["points"]:
        points_list = [
            [int(p["x"] * original_width), int(p["y"] * original_height)]
            for p in result.get("points", [])
        ]
        if not points_list:
            return image

        points_array = np.array(points_list).reshape(1, -1, 2)
        key_points = sv.KeyPoints(xy=points_array)
        vertex_annotator = sv.VertexAnnotator(radius=4, color=sv.Color.RED)
        annotated_image = vertex_annotator.annotate(scene=np.array(image.copy()), key_points=key_points)
        return Image.fromarray(annotated_image)

    if "objects" in result and result["objects"]:
        boxes = []
        for obj in result["objects"]:
            x_min = obj.get("x_min", 0.0) * original_width
            y_min = obj.get("y_min", 0.0) * original_height
            x_max = obj.get("x_max", 0.0) * original_width
            y_max = obj.get("y_max", 0.0) * original_height
            boxes.append([x_min, y_min, x_max, y_max])

        if not boxes:
            return image

        detections = sv.Detections(xyxy=np.array(boxes))

        if len(detections) == 0:
            return image

        box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=2)
        annotated_image = box_annotator.annotate(scene=np.array(image.copy()), detections=detections)
        return Image.fromarray(annotated_image)

    return image

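# Single-turn inference helper: build a chat message with one image and one
# text prompt, generate up to 512 new tokens, and decode only the newly
# generated tokens (the prompt portion of each sequence is sliced off first).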
def run_qwen_inference(image: Image.Image, prompt: str):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    inputs = qwen_processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.inference_mode():
        generated_ids = qwen_model.generate(
            **inputs,
            max_new_tokens=512,
        )

    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return qwen_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]


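# The parsing below assumes Qwen3-VL reports point_2d / bbox_2d coordinates on
# a 0-1000 grid (as in earlier Qwen-VL releases); dividing by 1000 yields the
# normalized [0, 1] values that annotate_image expects. If the model returns
# pixel coordinates instead, the overlays will be misplaced.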
@GPU
def process_qwen(image: Image.Image, category: str, prompt: str):
    if category == "Query":
        return run_qwen_inference(image, prompt), {}
    elif category == "Caption":
        full_prompt = f"Provide a {prompt} length caption for the image."
        return run_qwen_inference(image, full_prompt), {}
    elif category == "Point":
        full_prompt = (
            f"Provide 2d point coordinates for {prompt}. Report in JSON format."
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)
        points_result = {"points": []}
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if "point_2d" in item and len(item["point_2d"]) == 2:
                    x, y = item["point_2d"]
                    points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
        return json.dumps(points_result, indent=2), points_result
    elif category == "Detect":
        full_prompt = (
            f"Provide bounding box coordinates for {prompt}. Report in JSON format."
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)
        objects_result = {"objects": []}
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if "bbox_2d" in item and len(item["bbox_2d"]) == 4:
                    xmin, ymin, xmax, ymax = item["bbox_2d"]
                    objects_result["objects"].append(
                        {
                            "x_min": xmin / 1000.0,
                            "y_min": ymin / 1000.0,
                            "x_max": xmax / 1000.0,
                            "y_max": ymax / 1000.0,
                        }
                    )
        return json.dumps(objects_result, indent=2), objects_result
    return "Invalid category", {}

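# Gradio callback: validate inputs, downscale the image in place to at most
# 512x512 (thumbnail preserves aspect ratio) to keep inference fast, then run
# the model and annotate a copy of the resized image.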
def process_inputs(image, category, prompt):
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt:
        raise gr.Error("Please provide a prompt.")

    image.thumbnail((512, 512))

    qwen_text, qwen_data = process_qwen(image, category, prompt)
    qwen_annotated_image = annotate_image(image.copy(), qwen_data)

    return qwen_annotated_image, qwen_text

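# Swap the prompt placeholder to an example that matches the selected task.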
def on_category_change(category: str):
    if category == "Query":
        return gr.Textbox(placeholder="e.g., Count the total number of boats and describe the environment.")
    elif category == "Caption":
        return gr.Textbox(placeholder="e.g., short, normal, detailed")
    elif category == "Point":
        return gr.Textbox(placeholder="e.g., The gun held by the person.")
    elif category == "Detect":
        return gr.Textbox(placeholder="e.g., The headlight of the car.")
    return gr.Textbox(placeholder="e.g., detect the object.")


css = """
#main-title h1 {
    font-size: 2.3em !important;
}
#output-title h2 {
    font-size: 2.1em !important;
}
"""

with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# **Qwen-3VL: Multimodal Understanding**", elem_id="main-title")

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload Image")
                category_select = gr.Radio(
                    choices=CATEGORIES,
                    value="Query",
                    label="Select Task Category",
                    interactive=True,
                )
                prompt_input = gr.Textbox(
                    placeholder="e.g., Count the total number of boats and describe the environment.",
                    label="Prompt",
                    lines=1,
                )
                submit_btn = gr.Button("Process Image", variant="primary")

            with gr.Column(scale=2):
                qwen_img_output = gr.Image(label="Output Image")
                qwen_text_output = gr.Textbox(
                    label="Text Output", lines=10, interactive=False, show_copy_button=True
                )

        gr.Examples(
            examples=[
                ["examples/5.jpg", "Point", "Detect the children who are out of focus and wearing a white T-shirt."],
                ["examples/5.jpg", "Detect", "Point out the out-of-focus (all) children."],
                ["examples/4.jpg", "Detect", "Headlight"],
                ["examples/3.jpg", "Point", "Gun"],
                ["examples/1.jpg", "Query", "Count the total number of boats and describe the environment."],
                ["examples/2.jpg", "Caption", "a brief"],
            ],
            inputs=[image_input, category_select, prompt_input],
        )

    category_select.change(
        fn=on_category_change,
        inputs=[category_select],
        outputs=[prompt_input],
    )

    submit_btn.click(
        fn=process_inputs,
        inputs=[image_input, category_select, prompt_input],
        outputs=[qwen_img_output, qwen_text_output],
    )

if __name__ == "__main__":
    demo.launch(mcp_server=True, ssr_mode=False, show_error=True)