prithivMLmods committed
Commit 4306537 · verified · 1 Parent(s): 42957be

update app

Files changed (1)
  1. app.py +236 -295
app.py CHANGED
@@ -1,319 +1,260 @@
-import spaces
-import json
-import math
-import os
-import traceback
-from io import BytesIO
-from typing import Any, Dict, List, Optional, Tuple, Iterable
-import re
-import time
-from threading import Thread
-from io import BytesIO
-import uuid
-import tempfile
-
 import gradio as gr
-import numpy as np
 import torch
-from PIL import Image
 import supervision as sv
-
 from transformers import (
     Qwen3VLForConditionalGeneration,
-    AutoModelForCausalLM,
-    AutoProcessor,
-)
-from gradio.themes import Soft
-from gradio.themes.utils import colors, fonts, sizes
-
-# --- Theme and CSS Definition ---
-
-# Define the SteelBlue color palette
-colors.steel_blue = colors.Color(
-    name="steel_blue",
-    c50="#EBF3F8",
-    c100="#D3E5F0",
-    c200="#A8CCE1",
-    c300="#7DB3D2",
-    c400="#529AC3",
-    c500="#4682B4", # SteelBlue base color
-    c600="#3E72A0",
-    c700="#36638C",
-    c800="#2E5378",
-    c900="#264364",
-    c950="#1E3450",
 )


-class SteelBlueTheme(Soft):
-    def __init__(
-        self,
-        *,
-        primary_hue: colors.Color | str = colors.gray,
-        secondary_hue: colors.Color | str = colors.steel_blue,
-        neutral_hue: colors.Color | str = colors.slate,
-        text_size: sizes.Size | str = sizes.text_lg,
-        font: fonts.Font | str | Iterable[fonts.Font | str] = (
-            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
-        ),
-        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
-            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
-        ),
-    ):
-        super().__init__(
-            primary_hue=primary_hue,
-            secondary_hue=secondary_hue,
-            neutral_hue=neutral_hue,
-            text_size=text_size,
-            font=font,
-            font_mono=font_mono,
 )
-        super().set(
-            background_fill_primary="*primary_50",
-            background_fill_primary_dark="*primary_900",
-            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
-            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
-            button_primary_text_color="white",
-            button_primary_text_color_hover="white",
-            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
-            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
-            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
-            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
-            slider_color="*secondary_500",
-            slider_color_dark="*secondary_600",
-            block_title_text_weight="600",
-            block_border_width="3px",
-            block_shadow="*shadow_drop_lg",
-            button_primary_shadow="*shadow_drop_lg",
-            button_large_padding="11px",
-            color_accent_soft="*primary_100",
-            block_label_background_fill="*primary_200",
 )

-# Instantiate the new theme
-steel_blue_theme = SteelBlueTheme()
-
-css = """
-#main-title h1 {
-    font-size: 2.3em !important;
-}
-#output-title h2 {
-    font-size: 2.1em !important;
-}
-"""
-
-
-# --- Constants and Model Setup ---
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-print("--- System Information ---")
-print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
-print("torch.__version__ =", torch.__version__)
-print("torch.version.cuda =", torch.version.cuda)
-print("CUDA available:", torch.cuda.is_available())
-print("CUDA device count:", torch.cuda.device_count())
-if torch.cuda.is_available():
-    print("Current device:", torch.cuda.current_device())
-    print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
-print("Using device:", device)
-print("--------------------------")
-
-
-# --- Model Loading ---
-
-# Load moondream3
-print("Loading moondream3-preview...")
-MODEL_ID_MD3 = "Qwen/Qwen3-VL-32B-Instruct"
-model_md3 = Qwen3VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_MD3,
-    trust_remote_code=True,
-    torch_dtype=torch.bfloat16,
-    device_map={"": "cuda"},
-)
-model_md3.compile()
-print("moondream3-preview loaded and compiled.")
-
-
-# --- Moondream3 Utility Functions ---
-
-def create_annotated_image(image, detection_result, object_name="Object"):
-    if not isinstance(detection_result, dict) or "objects" not in detection_result:
-        return image
-
-    original_width, original_height = image.size
-    annotated_image = np.array(image.convert("RGB"))

-    bboxes = []
-    labels = []

-    for i, obj in enumerate(detection_result["objects"]):
-        x_min = int(obj["x_min"] * original_width)
-        y_min = int(obj["y_min"] * original_height)
-        x_max = int(obj["x_max"] * original_width)
-        y_max = int(obj["y_max"] * original_height)

-        x_min = max(0, min(x_min, original_width))
-        y_min = max(0, min(y_min, original_height))
-        x_max = max(0, min(x_max, original_width))
-        y_max = max(0, min(y_max, original_height))

-        if x_max > x_min and y_max > y_min:
-            bboxes.append([x_min, y_min, x_max, y_max])
-            labels.append(f"{object_name} {i+1}")

-    if not bboxes:
-        return image

-    detections = sv.Detections(
-        xyxy=np.array(bboxes, dtype=np.float32),
-        class_id=np.arange(len(bboxes))
 )

-    bounding_box_annotator = sv.BoxAnnotator(
-        thickness=3,
-        color_lookup=sv.ColorLookup.INDEX
 )
-    label_annotator = sv.LabelAnnotator(
-        text_thickness=2,
-        text_scale=0.6,
-        color_lookup=sv.ColorLookup.INDEX
 )
-
-    annotated_image = bounding_box_annotator.annotate(
-        scene=annotated_image, detections=detections
 )
-    annotated_image = label_annotator.annotate(
-        scene=annotated_image, detections=detections, labels=labels
-    )
-
-    return Image.fromarray(annotated_image)
-
-def create_point_annotated_image(image, point_result):
-    if not isinstance(point_result, dict) or "points" not in point_result:
-        return image
-
-    original_width, original_height = image.size
-    annotated_image = np.array(image.convert("RGB"))
-
-    points = []
-    for point in point_result["points"]:
-        x = int(point["x"] * original_width)
-        y = int(point["y"] * original_height)
-        points.append([x, y])
-
-    if points:
-        points_array = np.array(points).reshape(1, -1, 2)
-        key_points = sv.KeyPoints(xy=points_array)
-        vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
-        annotated_image = vertex_annotator.annotate(
-            scene=annotated_image, key_points=key_points
-        )
-
-    return Image.fromarray(annotated_image)
-
-@spaces.GPU()
-def detect_objects_md3(image, prompt, task_type, max_objects):
-    STANDARD_SIZE = (1024, 1024)
-    if image is None:
-        raise gr.Error("Please upload an image.")
-    image.thumbnail(STANDARD_SIZE)
-
-    t0 = time.perf_counter()
-
-    if task_type == "Object Detection":
-        settings = {"max_objects": max_objects} if max_objects > 0 else {}
-        result = model_md3.detect(image, prompt, settings=settings)
-        annotated_image = create_annotated_image(image, result, prompt)
-    elif task_type == "Point Detection":
-        result = model_md3.point(image, prompt)
-        annotated_image = create_point_annotated_image(image, result)
-    elif task_type == "Caption":
-        result = model_md3.caption(image, length="normal")
-        annotated_image = image
-    else:
-        result = model_md3.query(image=image, question=prompt, reasoning=True)
-        annotated_image = image
-
-    elapsed_ms = (time.perf_counter() - t0) * 1_000
-
-    if isinstance(result, dict):
-        if "objects" in result:
-            output_text = f"Found {len(result['objects'])} objects:\n"
-            for i, obj in enumerate(result['objects'], 1):
-                output_text += f"\n{i}. Bounding box: ({obj['x_min']:.3f}, {obj['y_min']:.3f}, {obj['x_max']:.3f}, {obj['y_max']:.3f})"
-        elif "points" in result:
-            output_text = f"Found {len(result['points'])} points:\n"
-            for i, point in enumerate(result['points'], 1):
-                output_text += f"\n{i}. Point: ({point['x']:.3f}, {point['y']:.3f})"
-        elif "caption" in result:
-            output_text = result['caption']
-        elif "answer" in result:
-            output_text = f"Reasoning: {result.get('reasoning', 'N/A')}\n\nAnswer: {result['answer']}"
-        else:
-            output_text = json.dumps(result, indent=2)
-    else:
-        output_text = str(result)
-
-    timing_text = f"Inference time: {elapsed_ms:.0f} ms"
-
-    return annotated_image, output_text, timing_text
-
-
-# --- Gradio Interface ---
-
-def create_gradio_interface():
-    """Builds and returns the Gradio web interface."""
-
-    with gr.Blocks(theme=steel_blue_theme, css=css) as demo:
-        gr.Markdown("# **🌝 Moondream3 Lab**", elem_id="main-title")
-        gr.Markdown("Explore the capabilities of the Moondream3 Vision Language Model for tasks like Object/Point Detection, VQA, and Captioning.")
-
-        with gr.Row():
-            with gr.Column(scale=1):
-                md3_image_input = gr.Image(label="Upload an image", type="pil", height=400)
-                md3_task_type = gr.Radio(
-                    choices=["Object Detection", "Point Detection", "Caption", "Visual Question Answering"],
-                    label="Task Type", value="Object Detection"
-                )
-                md3_prompt_input = gr.Textbox(
-                    label="Prompt (object to detect/question to ask)",
-                    placeholder="e.g., 'car', 'person', 'What's in this image?'"
-                )
-                md3_max_objects = gr.Number(
-                    label="Max Objects (for Object Detection only)",
-                    value=10, minimum=1, maximum=50, step=1, visible=True
-                )
-                md3_generate_btn = gr.Button(value="Submit", variant="primary")
-            with gr.Column(scale=1):
-                md3_output_image = gr.Image(type="pil", label="Result", height=400)
-                md3_output_textbox = gr.Textbox(label="Model Response", lines=10, show_copy_button=True)
-                md3_output_time = gr.Markdown()
-
-        gr.Examples(
-            examples=[
-                ["md3/1.jpg", "Object Detection", "boats", 7],
-                ["md3/2.jpg", "Point Detection", "children", 7],
-                ["md3/3.png", "Caption", "", 5],
-                ["md3/4.jpeg", "Visual Question Answering", "Analyze the GDP trend over the years.", 5],
-            ],
-            inputs=[md3_image_input, md3_task_type, md3_prompt_input, md3_max_objects],
-            label="Click an example to populate inputs"
-        )
-
-        # Event listeners for the interface
-        def update_max_objects_visibility(task):
-            return gr.update(visible=(task == "Object Detection"))
-
-        md3_task_type.change(fn=update_max_objects_visibility, inputs=[md3_task_type], outputs=[md3_max_objects])
-
-        md3_generate_btn.click(
-            fn=detect_objects_md3,
-            inputs=[md3_image_input, md3_prompt_input, md3_task_type, md3_max_objects],
-            outputs=[md3_output_image, md3_output_textbox, md3_output_time]
-        )
-
-    return demo

 if __name__ == "__main__":
-    demo = create_gradio_interface()
-    demo.queue(max_size=50).launch(ssr_mode=False, mcp_server=True, show_error=True)
 import gradio as gr
+from gradio.themes.ocean import Ocean
 import torch
+import numpy as np
 import supervision as sv
 from transformers import (
     Qwen3VLForConditionalGeneration,
+    Qwen3VLProcessor,
 )
+import json
+import ast
+import re
+from PIL import Image
+from spaces import GPU
+
+# --- Constants and Configuration ---
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = "auto"
+
+CATEGORIES = ["Query", "Caption", "Point", "Detect"]
+PLACEHOLDERS = {
+    "Query": "What is in this image?",
+    "Caption": "Select a caption length from the suggestions below.",
+    "Point": "Select an object from suggestions or enter a custom one.",
+    "Detect": "Select an object from suggestions or enter a custom one.",
+}

+qwen_model = Qwen3VLForConditionalGeneration.from_pretrained(
+    "Qwen/Qwen3-VL-32B-Instruct",
+    torch_dtype=DTYPE,
+    device_map=DEVICE,
+).eval()
+qwen_processor = Qwen3VLProcessor.from_pretrained(
+    "Qwen/Qwen3-VL-32B-Instruct",
+)
+print("Model loaded successfully.")
+
+
+# --- Utility Functions ---
+def safe_parse_json(text: str):
+    """Safely parse JSON or Python literal from a string, cleaning it first."""
+    # Find the JSON object within the text
+    match = re.search(r'\{.*\}', text, re.DOTALL)
+    if not match:
+        return {}
+    text = match.group(0)
+    try:
+        return json.loads(text)
+    except json.JSONDecodeError:
+        try:
+            # Fallback for Python dictionary literals
+            return ast.literal_eval(text)
+        except (ValueError, SyntaxError):
+            return {}
+
+
+def annotate_image(image: Image.Image, result: dict, category: str):
+    """Draws annotations on the image based on the model's output."""
+    if not isinstance(image, Image.Image) or not isinstance(result, dict):
+        return image

+    image_np = np.array(image.convert("RGB"))
+
+    # Handle Point annotations
+    if category == "Point" and "points" in result and result["points"]:
+        points_xy = np.array(result["points"])
+        if points_xy.size == 0:
+            return image
+
+        # Denormalize points from [0, 1] range to image dimensions
+        points_xy *= np.array([image.width, image.height])
+
+        key_points = sv.KeyPoints(xy=points_xy.reshape(1, -1, 2))
+        annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
+        annotated_image = annotator.annotate(scene=image_np.copy(), key_points=key_points)
+        return Image.fromarray(annotated_image)
+
+    # Handle Detection annotations
+    if category == "Detect" and "objects" in result and result["objects"]:
+        boxes_xyxy = np.array(result["objects"])
+        if boxes_xyxy.size == 0:
+            return image
+
+        # Denormalize boxes from [0, 1] range to image dimensions
+        boxes_xyxy *= np.array([image.width, image.height, image.width, image.height])
+
+        detections = sv.Detections(xyxy=boxes_xyxy)
+        annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=4)
+        annotated_image = annotator.annotate(scene=image_np.copy(), detections=detections)
+        return Image.fromarray(annotated_image)
+
+    return image
+
+
+# --- Inference Functions ---
+def run_qwen_inference(image: Image.Image, prompt: str):
+    """Core function to run inference with the Qwen3-VL model."""
+    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
+    inputs = qwen_processor.apply_chat_template(
+        messages,
+        tokenize=True,
+        add_generation_prompt=True,
+        return_dict=True,
+        return_tensors="pt",
+    ).to(DEVICE)
+
+    with torch.inference_mode():
+        generated_ids = qwen_model.generate(**inputs, max_new_tokens=512)
+
+    generated_ids_trimmed = generated_ids[:, inputs.input_ids.shape[1]:]
+    output_text = qwen_processor.batch_decode(
+        generated_ids_trimmed,
+        skip_special_tokens=True,
+        clean_up_tokenization_spaces=False,
+    )[0]
+    return output_text
+
+
+@GPU
+def get_suggested_objects(image: Image.Image):
+    """Get suggested objects in the image using Qwen3-VL to populate radio buttons."""
+    if image is None:
+        return gr.Radio(choices=[], visible=False)
+
+    try:
+        prompt = "List the 3 most prominent objects in this image as a Python list of strings. Example: ['car', 'tree', 'person']"
+        result_text = run_qwen_inference(image, prompt)
+
+        match = re.search(r'\[.*?\]', result_text)
+        if match:
+            suggestions = ast.literal_eval(match.group())
+            if isinstance(suggestions, list) and suggestions:
+                return gr.Radio(choices=suggestions, visible=True, interactive=True)
+    except Exception as e:
+        print(f"Error getting suggestions with Qwen: {e}")
+
+    return gr.Radio(choices=[], visible=False)
+
+
+@GPU
+def process_qwen(image: Image.Image, category: str, prompt: str):
+    """Process inputs based on the selected category, returning text and data for annotation."""
+    if category == "Query":
+        return run_qwen_inference(image, prompt), {}
+
+    elif category == "Caption":
+        full_prompt = f"Provide a {prompt} length caption for the image."
+        return run_qwen_inference(image, full_prompt), {}
+
+    elif category == "Point":
+        full_prompt = (
+            f"Provide 2D point coordinates for '{prompt}'. Respond ONLY with a JSON object like "
+            f"`{{\"points\": [[x1, y1], [x2, y2], ...]}}`. The coordinates must be normalized between 0.0 and 1.0."
 )
+        output_text = run_qwen_inference(image, full_prompt)
+        parsed_json = safe_parse_json(output_text)
+        # Ensure the parsed data has the correct structure
+        if "points" not in parsed_json or not isinstance(parsed_json["points"], list):
+            return output_text, {}
+        return output_text, parsed_json
+
+    elif category == "Detect":
+        full_prompt = (
+            f"Provide bounding box coordinates for '{prompt}'. Respond ONLY with a JSON object like "
+            f"`{{\"objects\": [[x_min, y_min, x_max, y_max], ...]}}`. The coordinates must be normalized between 0.0 and 1.0."
 )
+        output_text = run_qwen_inference(image, full_prompt)
+        parsed_json = safe_parse_json(output_text)
+        if "objects" not in parsed_json or not isinstance(parsed_json["objects"], list):
+            return output_text, {}
+        return output_text, parsed_json
+
+    return "Invalid category", {}


+# --- Gradio Interface Logic ---
+def on_category_and_image_change(image, category):
+    """Handle UI changes when the image or category is updated."""
+    text_box = gr.Textbox(value="", placeholder=PLACEHOLDERS.get(category, ""), interactive=True)

+    if category == "Caption":
+        return gr.Radio(choices=["short", "normal", "long"], value="normal", visible=True), text_box

+    if image is None or category not in ["Point", "Detect"]:
+        return gr.Radio(choices=[], visible=False), text_box

+    return get_suggested_objects(image), text_box


+def process_inputs(image, category, prompt):
+    """Main function to handle the user's submission."""
+    if image is None:
+        raise gr.Error("Please upload an image.")
+    if not prompt and category not in ["Caption"]:
+        raise gr.Error("Please provide a prompt or select a suggestion.")
+    if category == "Caption" and not prompt:
+        prompt = "normal" # Default caption length
+
+    image.thumbnail((1024, 1024)) # Resize for faster inference
+
+    qwen_text, qwen_data = process_qwen(image, category, prompt)
+    qwen_annotated_image = annotate_image(image, qwen_data, category)
+
+    return qwen_annotated_image, qwen_text
+
+
+# --- Gradio UI Layout ---
+with gr.Blocks(theme=Ocean()) as demo:
+    gr.Markdown("# 👓 Object Understanding with Qwen3-VL")
+    gr.Markdown("### Explore object detection, keypoint detection, and captioning using natural language prompts.")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            image_input = gr.Image(type="pil", label="Input Image")
+            category_select = gr.Radio(
+                choices=CATEGORIES, value=CATEGORIES[0], label="Select Task", interactive=True
+            )
+            suggestions_radio = gr.Radio(
+                choices=[], label="Suggestions", visible=False, interactive=True
+            )
+            prompt_input = gr.Textbox(
+                placeholder=PLACEHOLDERS[CATEGORIES[0]], label="Prompt", lines=2
+            )
+            submit_btn = gr.Button("Generate", variant="primary")
+
+        with gr.Column(scale=2):
+            gr.Markdown("### Qwen/Qwen3-VL-32B-Instruct Output")
+            qwen_img_output = gr.Image(label="Annotated Image")
+            qwen_text_output = gr.Textbox(label="Text Output", lines=8, interactive=False, show_copy_button=True)
+
+    gr.Examples(
+        examples=[
+            ["examples/cars.jpg", "Query", "How many cars are in the image?"],
+            ["examples/dog_beach.jpg", "Detect", "dog"],
+            ["examples/person_skiing.jpg", "Point", "the person's head"],
+            ["examples/dog_beach.jpg", "Caption", "short"],
+        ],
+        inputs=[image_input, category_select, prompt_input],
 )

+    # --- Event Listeners ---
+    category_select.change(
+        fn=on_category_and_image_change,
+        inputs=[image_input, category_select],
+        outputs=[suggestions_radio, prompt_input],
 )
+    image_input.change(
+        fn=on_category_and_image_change,
+        inputs=[image_input, category_select],
+        outputs=[suggestions_radio, prompt_input],
 )
+    suggestions_radio.change(fn=lambda x: x, inputs=suggestions_radio, outputs=prompt_input)
+    submit_btn.click(
+        fn=process_inputs,
+        inputs=[image_input, category_select, prompt_input],
+        outputs=[qwen_img_output, qwen_text_output],
 )

 if __name__ == "__main__":
+    demo.launch()