prithivMLmods committed (verified)
Commit 42957be · 1 Parent(s): eec5a19

Update app.py

Files changed (1)
  1. app.py +222 -508
app.py CHANGED
@@ -1,59 +1,56 @@
- import os
- import random
- import uuid
  import json
- import time
- import asyncio
  import re
  from threading import Thread
- from pathlib import Path
  from io import BytesIO
- from typing import Optional, Tuple, Dict, Any, Iterable

  import gradio as gr
- import spaces
- import torch
  import numpy as np
  from PIL import Image
- import cv2
- import requests
- import fitz
  import supervision as sv

  from transformers import (
- Qwen3VLMoeForConditionalGeneration,
  AutoProcessor,
- TextIteratorStreamer,
  )
- from transformers.image_utils import load_image
-
  from gradio.themes import Soft
  from gradio.themes.utils import colors, fonts, sizes

  # --- Theme and CSS Definition ---

- # Define the new OrangeRed color palette
- colors.orange_red = colors.Color(
- name="orange_red",
- c50="#FFF0E5",
- c100="#FFE0CC",
- c200="#FFC299",
- c300="#FFA366",
- c400="#FF8533",
- c500="#FF4500", # OrangeRed base color
- c600="#E63E00",
- c700="#CC3700",
- c800="#B33000",
- c900="#992900",
- c950="#802200",
  )

- class OrangeRedTheme(Soft):
  def __init__(
  self,
  *,
  primary_hue: colors.Color | str = colors.gray,
- secondary_hue: colors.Color | str = colors.orange_red, # Use the new color
  neutral_hue: colors.Color | str = colors.slate,
  text_size: sizes.Size | str = sizes.text_lg,
  font: fonts.Font | str | Iterable[fonts.Font | str] = (
@@ -82,12 +79,6 @@ class OrangeRedTheme(Soft):
  button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
  button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
  button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
- button_secondary_text_color="black",
- button_secondary_text_color_hover="white",
- button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
- button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
- button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
- button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
  slider_color="*secondary_500",
  slider_color_dark="*secondary_600",
  block_title_text_weight="600",
@@ -100,7 +91,7 @@ class OrangeRedTheme(Soft):
  )

  # Instantiate the new theme
- orange_red_theme = OrangeRedTheme()

  css = """
  #main-title h1 {
@@ -109,497 +100,220 @@ css = """
  #output-title h2 {
  font-size: 2.1em !important;
  }
- :root {
- --color-grey-50: #f9fafb;
- --banner-background: var(--secondary-400);
- --banner-text-color: var(--primary-100);
- --banner-background-dark: var(--secondary-800);
- --banner-text-color-dark: var(--primary-100);
- --banner-chrome-height: calc(16px + 43px);
- --chat-chrome-height-wide-no-banner: 320px;
- --chat-chrome-height-narrow-no-banner: 450px;
- --chat-chrome-height-wide: calc(var(--chat-chrome-height-wide-no-banner) + var(--banner-chrome-height));
- --chat-chrome-height-narrow: calc(var(--chat-chrome-height-narrow-no-banner) + var(--banner-chrome-height));
- }
- .banner-message { background-color: var(--banner-background); padding: 5px; margin: 0; border-radius: 5px; border: none; }
- .banner-message-text { font-size: 13px; font-weight: bolder; color: var(--banner-text-color) !important; }
- body.dark .banner-message { background-color: var(--banner-background-dark) !important; }
- body.dark .gradio-container .contain .banner-message .banner-message-text { color: var(--banner-text-color-dark) !important; }
- .toast-body { background-color: var(--color-grey-50); }
- .html-container:has(.css-styles) { padding: 0; margin: 0; }
- .css-styles { height: 0; }
- .model-message { text-align: end; }
- .model-dropdown-container { display: flex; align-items: center; gap: 10px; padding: 0; }
- .user-input-container .multimodal-textbox{ border: none !important; }
- .control-button { height: 51px; }
- button.cancel { border: var(--button-border-width) solid var(--button-cancel-border-color); background: var(--button-cancel-background-fill); color: var(--button-cancel-text-color); box-shadow: var(--button-cancel-shadow); }
- button.cancel:hover, .cancel[disabled] { background: var(--button-cancel-background-fill-hover); color: var(--button-cancel-text-color-hover); }
- .opt-out-message { top: 8px; }
- .opt-out-message .html-container, .opt-out-checkbox label { font-size: 14px !important; padding: 0 !important; margin: 0 !important; color: var(--neutral-400) !important; }
- div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-wide)) !important; max-height: 900px !important; }
- div.no-padding { padding: 0 !important; }
- @media (max-width: 1280px) { div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-wide)) !important; } }
- @media (max-width: 1024px) {
- .responsive-row { flex-direction: column; }
- .model-message { text-align: start; font-size: 10px !important; }
- .model-dropdown-container { flex-direction: column; align-items: flex-start; }
- div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-narrow)) !important; }
- }
- @media (max-width: 400px) {
- .responsive-row { flex-direction: column; }
- .model-message { text-align: start; font-size: 10px !important; }
- .model-dropdown-container { flex-direction: column; align-items: flex-start; }
- div.block.chatbot { max-height: 360px !important; }
- }
- @media (max-height: 932px) { .chatbot { max-height: 500px !important; } }
- @media (max-height: 1280px) { div.block.chatbot { max-height: 800px !important; } }
  """

- MAX_MAX_NEW_TOKENS = 4096
- DEFAULT_MAX_NEW_TOKENS = 1024
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
  print("torch.__version__ =", torch.__version__)
  print("torch.version.cuda =", torch.version.cuda)
- print("cuda available:", torch.cuda.is_available())
- print("cuda device count:", torch.cuda.device_count())
  if torch.cuda.is_available():
- print("current device:", torch.cuda.current_device())
- print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
  print("Using device:", device)

- MODEL_ID_Q3VL = "Qwen/Qwen3-VL-30B-A3B-Instruct"
- processor_q3vl = AutoProcessor.from_pretrained(MODEL_ID_Q3VL, trust_remote_code=True, use_fast=False)
- model_q3vl = Qwen3VLMoeForConditionalGeneration.from_pretrained(
- MODEL_ID_Q3VL,
  trust_remote_code=True,
- dtype=torch.float16
- ).to(device).eval()
-
- # --- Utility functions for Detection and Drawing ---
-
- def parse_detection_output(text: str) -> list:
- """Parses the model's text output to extract bounding boxes or points."""
- match = re.search(r'\[\s*\[.*?\]\s*\]', text)
- if not match:
- return []
- try:
- result = json.loads(match.group(0))
- if isinstance(result, list) and all(isinstance(item, list) for item in result):
- return result
- return []
- except (json.JSONDecodeError, TypeError):
- return []
-
- def draw_object_detections(image: Image.Image, detections: list, labels: list) -> Image.Image:
- """Draws bounding boxes on the image."""
- image_np = np.array(image.convert("RGB"))
- h, w, _ = image_np.shape
- boxes = []
- for box in detections:
- if len(box) == 4:
- x1, y1, x2, y2 = box
- boxes.append([x1 * w, y1 * h, x2 * w, y2 * h])
- if not boxes:
  return image
- detections_sv = sv.Detections(xyxy=np.array(boxes))
- bounding_box_annotator = sv.BoxAnnotator(thickness=2)
- label_annotator = sv.LabelAnnotator(text_thickness=1, text_scale=0.5)
- annotated_image = bounding_box_annotator.annotate(scene=image_np.copy(), detections=detections_sv)
- annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections_sv, labels=labels)
- return Image.fromarray(annotated_image)

- def draw_point_detections(image: Image.Image, points: list) -> Image.Image:
- """Draws points on the image."""
- image_np = np.array(image.convert("RGB"))
- h, w, _ = image_np.shape
- pts = []
- for point in points:
- if len(point) == 2:
- x, y = point
- pts.append([x * w, y * h])
- if not pts:
  return image
- points_np = np.array(pts).reshape(1, -1, 2)
- key_points = sv.KeyPoints(xy=points_np)
- point_annotator = sv.VertexAnnotator(radius=5, color=sv.Color.RED)
- annotated_image = point_annotator.annotate(scene=image_np.copy(), key_points=key_points)
- return Image.fromarray(annotated_image)


- # --- Core Generation Functions ---
-
- def extract_gif_frames(gif_path: str):
- if not gif_path:
- return []
- with Image.open(gif_path) as gif:
- total_frames = gif.n_frames
- frame_indices = np.linspace(0, total_frames - 1, min(total_frames, 10), dtype=int)
- frames = []
- for i in frame_indices:
- gif.seek(i)
- frames.append(gif.convert("RGB").copy())
- return frames
-
- def downsample_video(video_path):
- vidcap = cv2.VideoCapture(video_path)
- total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
- frames = []
- frame_indices = np.linspace(0, total_frames - 1, min(total_frames, 10), dtype=int)
- for i in frame_indices:
- vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
- success, image = vidcap.read()
- if success:
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
- pil_image = Image.fromarray(image)
- frames.append(pil_image)
- vidcap.release()
- return frames
-
- def convert_pdf_to_images(file_path: str, dpi: int = 200):
- if not file_path:
- return []
- images = []
- pdf_document = fitz.open(file_path)
- zoom = dpi / 72.0
- mat = fitz.Matrix(zoom, zoom)
- for page_num in range(len(pdf_document)):
- page = pdf_document.load_page(page_num)
- pix = page.get_pixmap(matrix=mat)
- img_data = pix.tobytes("png")
- images.append(Image.open(BytesIO(img_data)))
- pdf_document.close()
- return images
-
- def get_initial_pdf_state() -> Dict[str, Any]:
- return {"pages": [], "total_pages": 0, "current_page_index": 0}
-
- def load_and_preview_pdf(file_path: Optional[str]) -> Tuple[Optional[Image.Image], Dict[str, Any], str]:
- state = get_initial_pdf_state()
- if not file_path:
- return None, state, '<div style="text-align:center;">No file loaded</div>'
- try:
- pages = convert_pdf_to_images(file_path)
- if not pages:
- return None, state, '<div style="text-align:center;">Could not load file</div>'
- state["pages"] = pages
- state["total_pages"] = len(pages)
- page_info_html = f'<div style="text-align:center;">Page 1 / {state["total_pages"]}</div>'
- return pages[0], state, page_info_html
- except Exception as e:
- return None, state, f'<div style="text-align:center;">Failed to load preview: {e}</div>'
-
- def navigate_pdf_page(direction: str, state: Dict[str, Any]):
- if not state or not state["pages"]:
- return None, state, '<div style="text-align:center;">No file loaded</div>'
- current_index = state["current_page_index"]
- total_pages = state["total_pages"]
- if direction == "prev":
- new_index = max(0, current_index - 1)
- elif direction == "next":
- new_index = min(total_pages - 1, current_index + 1)
- else:
- new_index = current_index
- state["current_page_index"] = new_index
- image_preview = state["pages"][new_index]
- page_info_html = f'<div style="text-align:center;">Page {new_index + 1} / {total_pages}</div>'
- return image_preview, state, page_info_html
-
- @spaces.GPU
- def generate_image(text: str, image: Image.Image, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
- if image is None:
- yield "Please upload an image.", "Please upload an image."
- return
- messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}]
- prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
- inputs = processor_q3vl(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
- streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
- generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
- thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
- thread.start()
- buffer = ""
- for new_text in streamer:
- buffer += new_text
- time.sleep(0.01)
- yield buffer, buffer
-
- @spaces.GPU
- def generate_video(text: str, video_path: str, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
- if video_path is None:
- yield "Please upload a video.", "Please upload a video."
- return
- frames = downsample_video(video_path)
- if not frames:
- yield "Could not process video.", "Could not process video."
- return
- messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
- for frame in frames:
- messages[0]["content"].insert(0, {"type": "image"})
- prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
- inputs = processor_q3vl(text=[prompt_full], images=frames, return_tensors="pt", padding=True).to(device)
- streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
- generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": True, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty}
- thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
- thread.start()
- buffer = ""
- for new_text in streamer:
- buffer += new_text
- buffer = buffer.replace("<|im_end|>", "")
- time.sleep(0.01)
- yield buffer, buffer
-
- @spaces.GPU
- def generate_pdf(text: str, state: Dict[str, Any], max_new_tokens: int = 2048, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
- if not state or not state["pages"]:
- yield "Please upload a PDF file first.", "Please upload a PDF file first."
- return
- page_images = state["pages"]
- full_response = ""
- for i, image in enumerate(page_images):
- page_header = f"--- Page {i+1}/{len(page_images)} ---\n"
- yield full_response + page_header, full_response + page_header
- messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}]
- prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
- inputs = processor_q3vl(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
- streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
- generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
- thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
- thread.start()
- page_buffer = ""
- for new_text in streamer:
- page_buffer += new_text
- yield full_response + page_header + page_buffer, full_response + page_header + page_buffer
- time.sleep(0.01)
- full_response += page_header + page_buffer + "\n\n"
-
- @spaces.GPU
- def generate_caption(image: Image.Image, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
- if image is None:
- yield "Please upload an image to caption.", "Please upload an image to caption."
- return
- system_prompt = (
- "You are an AI assistant that rigorously follows this response protocol: For every input image, your primary "
- "task is to write a precise caption that captures the essence of the image in clear, concise, and contextually "
- "accurate language. Along with the caption, provide a structured set of attributes describing the visual "
- "elements, including details such as objects, people, actions, colors, environment, mood, and other notable "
- "characteristics. Ensure captions are precise, neutral, and descriptive, avoiding unnecessary elaboration or "
- "subjective interpretation unless explicitly required. Do not reference the rules or instructions in the output; "
- "only return the formatted caption, attributes, and class_name."
  )
- messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": system_prompt}]}]
- prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
- inputs = processor_q3vl(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
- streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
- generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
- thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
- thread.start()
- buffer = ""
- for new_text in streamer:
- buffer += new_text
- time.sleep(0.01)
- yield buffer, buffer
-
- @spaces.GPU
- def generate_gif(text: str, gif_path: str, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
- if gif_path is None:
- yield "Please upload a GIF.", "Please upload a GIF."
- return
- frames = extract_gif_frames(gif_path)
- if not frames:
- yield "Could not process GIF.", "Could not process GIF."
- return
- messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
- for frame in frames:
- messages[0]["content"].insert(0, {"type": "image"})
- prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
- inputs = processor_q3vl(text=[prompt_full], images=frames, return_tensors="pt", padding=True).to(device)
- streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
- generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": True, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty}
- thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
- thread.start()
- buffer = ""
- for new_text in streamer:
- buffer += new_text
- buffer = buffer.replace("<|im_end|>", "")
- time.sleep(0.01)
- yield buffer, buffer
-
- @spaces.GPU
- def generate_detection(
- image: Image.Image, user_prompt: str, task_type: str, max_new_tokens: int = 256,
- temperature: float = 0.1, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2
- ):
  if image is None:
- return None, "Please upload an image."
- if not user_prompt:
- return image, "Please provide a prompt describing what to detect."

  if task_type == "Object Detection":
- system_prompt = (
- f"You are an expert object detector. Find all instances of '{user_prompt}' in the image. "
- "Respond ONLY with a Python list of bounding boxes in the format [[x_min, y_min, x_max, y_max], ...]. "
- "The coordinates must be normalized between 0.0 and 1.0."
- )
  elif task_type == "Point Detection":
- system_prompt = (
- f"You are an expert keypoint detector. Find the specific points for '{user_prompt}' in the image. "
- "Respond ONLY with a Python list of points in the format [[x, y], ...]. "
- "The coordinates must be normalized between 0.0 and 1.0."
- )
  else:
- return image, "Invalid task type specified."
-
- messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": system_prompt}]}]
- prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
- inputs = processor_q3vl(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
-
- generation_kwargs = {
- **inputs, "max_new_tokens": max_new_tokens, "do_sample": True, "temperature": temperature,
- "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty,
- }
-
- generate_ids = model_q3vl.generate(**generation_kwargs)
- generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
- response_text = processor_q3vl.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-
- try:
- coords = parse_detection_output(response_text)
- if not coords:
- return image, f"Could not detect '{user_prompt}'.\nModel raw output:\n{response_text}"
- if task_type == "Object Detection":
- labels = [f"{user_prompt} #{i+1}" for i in range(len(coords))]
- annotated_image = draw_object_detections(image, coords, labels)
- else: # Point Detection
- annotated_image = draw_point_detections(image, coords)
- return annotated_image, response_text
- except Exception as e:
- return image, f"An error occurred during processing:\n{str(e)}\n\nModel raw output:\n{response_text}"
-
479
-
480
- image_examples = [["Perform OCR on the image...", "examples/images/1.jpg"],
481
- ["Caption the image. Describe the safety measures shown in the image. Conclude whether the situation is (safe or unsafe)...", "examples/images/2.jpg"],
482
- ["Solve the problem...", "examples/images/3.png"]]
483
- video_examples = [["Explain the Ad video in detail.", "examples/videos/1.mp4"],
484
- ["Explain the video in detail.", "examples/videos/2.mp4"]]
485
- pdf_examples = [["Extract the content precisely.", "examples/pdfs/doc1.pdf"],
486
- ["Analyze and provide a short report.", "examples/pdfs/doc2.pdf"]]
487
- gif_examples = [["Describe this GIF.", "examples/gifs/1.gif"],
488
- ["Describe this GIF.", "examples/gifs/2.gif"]]
489
- caption_examples = [["examples/captions/1.JPG"],
490
- ["examples/captions/2.jpeg"], ["examples/captions/3.jpeg"]]
491
- # NOTE: You'll need to create these example image files in a directory named 'examples/detection/'
492
- obj_det_examples = [["examples/detection/obj1.jpg", "the two people"], ["examples/detection/obj2.jpg", "the yellow taxi"]]
493
- point_det_examples = [["examples/detection/point1.jpg", "the eyes of the person"], ["examples/detection/point2.jpg", "the headlights of the car"]]
494
-
495
-
496
- with gr.Blocks(theme=orange_red_theme, css=css) as demo:
497
- pdf_state = gr.State(value=get_initial_pdf_state())
498
- gr.Markdown("# **Qwen-3VL:Multimodal**", elem_id="main-title")
499
- with gr.Row():
500
- with gr.Column(scale=2):
501
- with gr.Tabs():
502
- with gr.TabItem("Image Inference"):
503
- image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
504
- image_upload = gr.Image(type="pil", label="Upload Image", height=290)
505
- image_submit = gr.Button("Submit", variant="primary")
506
- gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
507
-
508
- with gr.TabItem("Video Inference"):
509
- video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
510
- video_upload = gr.Video(label="Upload Video(≤30s)", height=290)
511
- video_submit = gr.Button("Submit", variant="primary")
512
- gr.Examples(examples=video_examples, inputs=[video_query, video_upload])
513
-
514
- with gr.TabItem("PDF Inference"):
515
- with gr.Row():
516
- with gr.Column(scale=1):
517
- pdf_query = gr.Textbox(label="Query Input", placeholder="e.g., 'Summarize this document'")
518
- pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
519
- pdf_submit = gr.Button("Submit", variant="primary")
520
- with gr.Column(scale=1):
521
- pdf_preview_img = gr.Image(label="PDF Preview", height=290)
522
- with gr.Row():
523
- prev_page_btn = gr.Button("◀ Previous")
524
- page_info = gr.HTML('<div style="text-align:center;">No file loaded</div>')
525
- next_page_btn = gr.Button("Next ▶")
526
- gr.Examples(examples=pdf_examples, inputs=[pdf_query, pdf_upload])
527
-
528
- with gr.TabItem("Gif Inference"):
529
- gif_query = gr.Textbox(label="Query Input", placeholder="e.g., 'What is happening in this gif?'")
530
- gif_upload = gr.Image(type="filepath", label="Upload GIF", height=290)
531
- gif_submit = gr.Button("Submit", variant="primary")
532
- gr.Examples(examples=gif_examples, inputs=[gif_query, gif_upload])
533
-
534
- with gr.TabItem("Caption"):
535
- caption_image_upload = gr.Image(type="pil", label="Image to Caption", height=290)
536
- caption_submit = gr.Button("Generate Caption", variant="primary")
537
- gr.Examples(examples=caption_examples, inputs=[caption_image_upload])
538
-
539
- with gr.TabItem("Object Detection"):
540
- with gr.Row():
541
- with gr.Column(scale=1):
542
- obj_det_image_upload = gr.Image(type="pil", label="Upload Image", height=290)
543
- obj_det_query = gr.Textbox(label="Object to Detect", placeholder="e.g., car, person, dog")
544
- obj_det_submit = gr.Button("Detect Objects", variant="primary")
545
- with gr.Column(scale=1):
546
- obj_det_output_image = gr.Image(type="pil", label="Detection Result", height=290)
547
- obj_det_output_text = gr.Textbox(label="Model Raw Output", interactive=False, lines=5)
548
- gr.Examples(examples=obj_det_examples, inputs=[obj_det_image_upload, obj_det_query])
549
-
550
- with gr.TabItem("Point Detection"):
551
- with gr.Row():
552
- with gr.Column(scale=1):
553
- point_det_image_upload = gr.Image(type="pil", label="Upload Image", height=290)
554
- point_det_query = gr.Textbox(label="Point(s) to Detect", placeholder="e.g., the eyes of the cat")
555
- point_det_submit = gr.Button("Detect Points", variant="primary")
556
- with gr.Column(scale=1):
557
- point_det_output_image = gr.Image(type="pil", label="Detection Result", height=290)
558
- point_det_output_text = gr.Textbox(label="Model Raw Output", interactive=False, lines=5)
559
- gr.Examples(examples=point_det_examples, inputs=[point_det_image_upload, point_det_query])
560
-
561
- with gr.Accordion("Advanced options", open=False):
562
- max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
563
- temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
564
- top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
565
- top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
566
- repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
567
-
568
- with gr.Column(scale=3):
569
- gr.Markdown("## Output", elem_id="output-title")
570
- output = gr.Textbox(label="Raw Output Stream (General Tasks)", interactive=False, lines=20, show_copy_button=True)
571
- with gr.Accordion("(Result.md)", open=False):
572
- markdown_output = gr.Markdown(label="(Result.Md)", latex_delimiters=[
573
- {"left": "$$", "right": "$$", "display": True},
574
- {"left": "$", "right": "$", "display": False}
575
- ])
576
-
577
- # Click handlers for original tabs
578
- image_submit.click(fn=generate_image, inputs=[image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[output, markdown_output])
579
- video_submit.click(fn=generate_video, inputs=[video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[output, markdown_output])
580
- pdf_submit.click(fn=generate_pdf, inputs=[pdf_query, pdf_state, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[output, markdown_output])
581
- gif_submit.click(fn=generate_gif, inputs=[gif_query, gif_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[output, markdown_output])
582
- caption_submit.click(fn=generate_caption, inputs=[caption_image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[output, markdown_output])
583
-
584
- # PDF navigation handlers
585
- pdf_upload.change(fn=load_and_preview_pdf, inputs=[pdf_upload], outputs=[pdf_preview_img, pdf_state, page_info])
586
- prev_page_btn.click(fn=lambda s: navigate_pdf_page("prev", s), inputs=[pdf_state], outputs=[pdf_preview_img, pdf_state, page_info])
587
- next_page_btn.click(fn=lambda s: navigate_pdf_page("next", s), inputs=[pdf_state], outputs=[pdf_preview_img, pdf_state, page_info])
588
-
589
- # Click handlers for NEW tabs
590
- obj_det_submit.click(
591
- fn=generate_detection,
592
- inputs=[obj_det_image_upload, obj_det_query, gr.Textbox(value="Object Detection", visible=False),
593
- max_new_tokens, temperature, top_p, top_k, repetition_penalty],
594
- outputs=[obj_det_output_image, obj_det_output_text]
595
- )
596
- point_det_submit.click(
597
- fn=generate_detection,
598
- inputs=[point_det_image_upload, point_det_query, gr.Textbox(value="Point Detection", visible=False),
599
- max_new_tokens, temperature, top_p, top_k, repetition_penalty],
600
- outputs=[point_det_output_image, point_det_output_text]
601
- )
602
 
 
603
 
604
  if __name__ == "__main__":
605
- demo.queue(max_size=50).launch(mcp_server=True, ssr_mode=False, show_error=True)
 
 
+ import spaces
  import json
+ import math
+ import os
+ import traceback
+ from io import BytesIO
+ from typing import Any, Dict, List, Optional, Tuple, Iterable
  import re
+ import time
  from threading import Thread
  from io import BytesIO
+ import uuid
+ import tempfile

  import gradio as gr
  import numpy as np
+ import torch
  from PIL import Image
  import supervision as sv

  from transformers import (
+ Qwen3VLForConditionalGeneration,
+ AutoModelForCausalLM,
  AutoProcessor,
  )
  from gradio.themes import Soft
  from gradio.themes.utils import colors, fonts, sizes

  # --- Theme and CSS Definition ---

+ # Define the SteelBlue color palette
+ colors.steel_blue = colors.Color(
+ name="steel_blue",
+ c50="#EBF3F8",
+ c100="#D3E5F0",
+ c200="#A8CCE1",
+ c300="#7DB3D2",
+ c400="#529AC3",
+ c500="#4682B4", # SteelBlue base color
+ c600="#3E72A0",
+ c700="#36638C",
+ c800="#2E5378",
+ c900="#264364",
+ c950="#1E3450",
  )

+
+ class SteelBlueTheme(Soft):
  def __init__(
  self,
  *,
  primary_hue: colors.Color | str = colors.gray,
+ secondary_hue: colors.Color | str = colors.steel_blue,
  neutral_hue: colors.Color | str = colors.slate,
  text_size: sizes.Size | str = sizes.text_lg,
  font: fonts.Font | str | Iterable[fonts.Font | str] = (
  button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
  button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
  button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
  slider_color="*secondary_500",
  slider_color_dark="*secondary_600",
  block_title_text_weight="600",
  )

  # Instantiate the new theme
+ steel_blue_theme = SteelBlueTheme()

  css = """
  #main-title h1 {
  #output-title h2 {
  font-size: 2.1em !important;
  }
  """

+
+ # --- Constants and Model Setup ---
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ print("--- System Information ---")
  print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
  print("torch.__version__ =", torch.__version__)
  print("torch.version.cuda =", torch.version.cuda)
+ print("CUDA available:", torch.cuda.is_available())
+ print("CUDA device count:", torch.cuda.device_count())
  if torch.cuda.is_available():
+ print("Current device:", torch.cuda.current_device())
+ print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
  print("Using device:", device)
+ print("--------------------------")
+
+
+ # --- Model Loading ---

+ # Load moondream3
+ print("Loading moondream3-preview...")
+ MODEL_ID_MD3 = "Qwen/Qwen3-VL-32B-Instruct"
+ model_md3 = Qwen3VLForConditionalGeneration.from_pretrained(
+ MODEL_ID_MD3,
  trust_remote_code=True,
+ torch_dtype=torch.bfloat16,
+ device_map={"": "cuda"},
+ )
+ model_md3.compile()
+ print("moondream3-preview loaded and compiled.")
+
+
+ # --- Moondream3 Utility Functions ---
+
+ def create_annotated_image(image, detection_result, object_name="Object"):
+ if not isinstance(detection_result, dict) or "objects" not in detection_result:
  return image

+ original_width, original_height = image.size
+ annotated_image = np.array(image.convert("RGB"))
+
+ bboxes = []
+ labels = []
+
+ for i, obj in enumerate(detection_result["objects"]):
+ x_min = int(obj["x_min"] * original_width)
+ y_min = int(obj["y_min"] * original_height)
+ x_max = int(obj["x_max"] * original_width)
+ y_max = int(obj["y_max"] * original_height)
+
+ x_min = max(0, min(x_min, original_width))
+ y_min = max(0, min(y_min, original_height))
+ x_max = max(0, min(x_max, original_width))
+ y_max = max(0, min(y_max, original_height))
+
+ if x_max > x_min and y_max > y_min:
+ bboxes.append([x_min, y_min, x_max, y_max])
+ labels.append(f"{object_name} {i+1}")
+
+ if not bboxes:
  return image

+ detections = sv.Detections(
+ xyxy=np.array(bboxes, dtype=np.float32),
+ class_id=np.arange(len(bboxes))
+ )

+ bounding_box_annotator = sv.BoxAnnotator(
+ thickness=3,
+ color_lookup=sv.ColorLookup.INDEX
+ )
+ label_annotator = sv.LabelAnnotator(
+ text_thickness=2,
+ text_scale=0.6,
+ color_lookup=sv.ColorLookup.INDEX
+ )
+
+ annotated_image = bounding_box_annotator.annotate(
+ scene=annotated_image, detections=detections
  )
+ annotated_image = label_annotator.annotate(
+ scene=annotated_image, detections=detections, labels=labels
+ )
+
+ return Image.fromarray(annotated_image)
+
+ def create_point_annotated_image(image, point_result):
+ if not isinstance(point_result, dict) or "points" not in point_result:
+ return image
+
+ original_width, original_height = image.size
+ annotated_image = np.array(image.convert("RGB"))
+
+ points = []
+ for point in point_result["points"]:
+ x = int(point["x"] * original_width)
+ y = int(point["y"] * original_height)
+ points.append([x, y])
+
+ if points:
+ points_array = np.array(points).reshape(1, -1, 2)
+ key_points = sv.KeyPoints(xy=points_array)
+ vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
+ annotated_image = vertex_annotator.annotate(
+ scene=annotated_image, key_points=key_points
+ )
+
+ return Image.fromarray(annotated_image)
+
+ @spaces.GPU()
+ def detect_objects_md3(image, prompt, task_type, max_objects):
+ STANDARD_SIZE = (1024, 1024)
  if image is None:
+ raise gr.Error("Please upload an image.")
+ image.thumbnail(STANDARD_SIZE)
+
+ t0 = time.perf_counter()

  if task_type == "Object Detection":
+ settings = {"max_objects": max_objects} if max_objects > 0 else {}
+ result = model_md3.detect(image, prompt, settings=settings)
+ annotated_image = create_annotated_image(image, result, prompt)
  elif task_type == "Point Detection":
+ result = model_md3.point(image, prompt)
+ annotated_image = create_point_annotated_image(image, result)
+ elif task_type == "Caption":
+ result = model_md3.caption(image, length="normal")
+ annotated_image = image
  else:
+ result = model_md3.query(image=image, question=prompt, reasoning=True)
+ annotated_image = image
+
+ elapsed_ms = (time.perf_counter() - t0) * 1_000
+
+ if isinstance(result, dict):
+ if "objects" in result:
+ output_text = f"Found {len(result['objects'])} objects:\n"
+ for i, obj in enumerate(result['objects'], 1):
+ output_text += f"\n{i}. Bounding box: ({obj['x_min']:.3f}, {obj['y_min']:.3f}, {obj['x_max']:.3f}, {obj['y_max']:.3f})"
+ elif "points" in result:
+ output_text = f"Found {len(result['points'])} points:\n"
+ for i, point in enumerate(result['points'], 1):
+ output_text += f"\n{i}. Point: ({point['x']:.3f}, {point['y']:.3f})"
+ elif "caption" in result:
+ output_text = result['caption']
+ elif "answer" in result:
+ output_text = f"Reasoning: {result.get('reasoning', 'N/A')}\n\nAnswer: {result['answer']}"
+ else:
+ output_text = json.dumps(result, indent=2)
+ else:
+ output_text = str(result)
+
+ timing_text = f"Inference time: {elapsed_ms:.0f} ms"
+
+ return annotated_image, output_text, timing_text
+
+
+ # --- Gradio Interface ---
+
+ def create_gradio_interface():
+ """Builds and returns the Gradio web interface."""
+
+ with gr.Blocks(theme=steel_blue_theme, css=css) as demo:
+ gr.Markdown("# **🌝 Moondream3 Lab**", elem_id="main-title")
+ gr.Markdown("Explore the capabilities of the Moondream3 Vision Language Model for tasks like Object/Point Detection, VQA, and Captioning.")
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ md3_image_input = gr.Image(label="Upload an image", type="pil", height=400)
+ md3_task_type = gr.Radio(
+ choices=["Object Detection", "Point Detection", "Caption", "Visual Question Answering"],
+ label="Task Type", value="Object Detection"
+ )
+ md3_prompt_input = gr.Textbox(
+ label="Prompt (object to detect/question to ask)",
+ placeholder="e.g., 'car', 'person', 'What's in this image?'"
+ )
+ md3_max_objects = gr.Number(
+ label="Max Objects (for Object Detection only)",
+ value=10, minimum=1, maximum=50, step=1, visible=True
+ )
+ md3_generate_btn = gr.Button(value="Submit", variant="primary")
+ with gr.Column(scale=1):
+ md3_output_image = gr.Image(type="pil", label="Result", height=400)
+ md3_output_textbox = gr.Textbox(label="Model Response", lines=10, show_copy_button=True)
+ md3_output_time = gr.Markdown()
+
+ gr.Examples(
+ examples=[
+ ["md3/1.jpg", "Object Detection", "boats", 7],
+ ["md3/2.jpg", "Point Detection", "children", 7],
+ ["md3/3.png", "Caption", "", 5],
+ ["md3/4.jpeg", "Visual Question Answering", "Analyze the GDP trend over the years.", 5],
+ ],
+ inputs=[md3_image_input, md3_task_type, md3_prompt_input, md3_max_objects],
+ label="Click an example to populate inputs"
+ )
+
+ # Event listeners for the interface
+ def update_max_objects_visibility(task):
+ return gr.update(visible=(task == "Object Detection"))
+
+ md3_task_type.change(fn=update_max_objects_visibility, inputs=[md3_task_type], outputs=[md3_max_objects])
+
+ md3_generate_btn.click(
+ fn=detect_objects_md3,
+ inputs=[md3_image_input, md3_prompt_input, md3_task_type, md3_max_objects],
+ outputs=[md3_output_image, md3_output_textbox, md3_output_time]
+ )

+ return demo

  if __name__ == "__main__":
+ demo = create_gradio_interface()
+ demo.queue(max_size=50).launch(ssr_mode=False, mcp_server=True, show_error=True)
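
For reference, the new create_annotated_image helper can be exercised without the model by passing a detection dict in the shape the updated app assumes: an "objects" list whose entries carry normalized x_min/y_min/x_max/y_max values. The sketch below is illustrative only; the blank image, box coordinates, and output filename are made-up placeholders, not part of the commit, and importing app.py in a real environment will also trigger the model load at import time.

# Minimal sketch: drive create_annotated_image() from the updated app.py with a
# hand-built detection dict instead of a model call. All values here are
# illustrative placeholders; importing app.py also loads the model on import.
from PIL import Image

from app import create_annotated_image  # assumes app.py is importable (GPU Space environment)

dummy_image = Image.new("RGB", (640, 480), color="white")  # placeholder image
dummy_result = {
    "objects": [
        {"x_min": 0.10, "y_min": 0.20, "x_max": 0.45, "y_max": 0.80},  # normalized coords
        {"x_min": 0.55, "y_min": 0.15, "x_max": 0.90, "y_max": 0.70},
    ]
}

annotated = create_annotated_image(dummy_image, dummy_result, object_name="box")
annotated.save("annotated_preview.png")  # boxes drawn via supervision's BoxAnnotator/LabelAnnotator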