# Hugging Face Space: Running on Zero (ZeroGPU). File size: 11,688 bytes.
import gradio as gr
import torch
import numpy as np
import supervision as sv
from typing import Iterable
from transformers import (
Qwen3VLForConditionalGeneration,
Qwen3VLProcessor,
)
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
import json
import ast
import re
from PIL import Image
from spaces import GPU
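# Attach a custom "steel_blue" Color to gradio's colors module so the theme
# below can reference it like any built-in hue.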
colors.steel_blue = colors.Color(
name="steel_blue",
c50="#EBF3F8",
c100="#D3E5F0",
c200="#A8CCE1",
c300="#7DB3D2",
c400="#529AC3",
c500="#4682B4",
c600="#3E72A0",
c700="#36638C",
c800="#2E5378",
c900="#264364",
c950="#1E3450",
)
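# Soft-based theme that maps the steel-blue palette onto buttons, sliders,
# and background gradients for both light and dark modes.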
class SteelBlueTheme(Soft):
def __init__(
self,
*,
primary_hue: colors.Color | str = colors.gray,
secondary_hue: colors.Color | str = colors.steel_blue,
neutral_hue: colors.Color | str = colors.slate,
text_size: sizes.Size | str = sizes.text_lg,
font: fonts.Font | str | Iterable[fonts.Font | str] = (
fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
),
font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
),
):
super().__init__(
primary_hue=primary_hue,
secondary_hue=secondary_hue,
neutral_hue=neutral_hue,
text_size=text_size,
font=font,
font_mono=font_mono,
)
super().set(
background_fill_primary="*primary_50",
background_fill_primary_dark="*primary_900",
body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
button_primary_text_color="white",
button_primary_text_color_hover="white",
button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
button_secondary_text_color="black",
button_secondary_text_color_hover="white",
button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
slider_color="*secondary_500",
slider_color_dark="*secondary_600",
block_title_text_weight="600",
block_border_width="3px",
block_shadow="*shadow_drop_lg",
button_primary_shadow="*shadow_drop_lg",
button_large_padding="11px",
color_accent_soft="*primary_100",
block_label_background_fill="*primary_200",
)
steel_blue_theme = SteelBlueTheme()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = "auto"
CATEGORIES = ["Query", "Caption", "Point", "Detect"]
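# Load Qwen3-VL-4B-Instruct once at startup. dtype="auto" defers to the
# precision stored in the checkpoint config (typically bfloat16 on GPU).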
qwen_model = Qwen3VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen3-VL-4B-Instruct",
dtype=DTYPE,
device_map=DEVICE,
).eval()
qwen_processor = Qwen3VLProcessor.from_pretrained(
"Qwen/Qwen3-VL-4B-Instruct",
)
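# The model often wraps JSON answers in markdown code fences; strip them and
# fall back to ast.literal_eval for almost-JSON (e.g. single quotes) before
# giving up and returning an empty dict.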
def safe_parse_json(text: str):
text = text.strip()
text = re.sub(r"^```(json)?", "", text)
text = re.sub(r"```$", "", text)
text = text.strip()
try:
return json.loads(text)
except json.JSONDecodeError:
pass
try:
return ast.literal_eval(text)
except Exception:
return {}
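# Overlay parsed points or boxes with supervision. Coordinates in `result`
# are normalized to [0, 1], so scale them back to pixel space first.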
def annotate_image(image: Image.Image, result: dict):
if not isinstance(image, Image.Image) or not isinstance(result, dict):
return image
    # Work on an RGB copy so annotation and the numpy conversion are consistent
    image = image.convert("RGB")
original_width, original_height = image.size
if "points" in result and result["points"]:
points_list = [
[int(p["x"] * original_width), int(p["y"] * original_height)]
for p in result.get("points", [])
]
if not points_list:
return image
points_array = np.array(points_list).reshape(1, -1, 2)
key_points = sv.KeyPoints(xy=points_array)
vertex_annotator = sv.VertexAnnotator(radius=4, color=sv.Color.RED)
annotated_image = vertex_annotator.annotate(scene=np.array(image.copy()), key_points=key_points)
return Image.fromarray(annotated_image)
if "objects" in result and result["objects"]:
boxes = []
for obj in result["objects"]:
x_min = obj.get("x_min", 0.0) * original_width
y_min = obj.get("y_min", 0.0) * original_height
x_max = obj.get("x_max", 0.0) * original_width
y_max = obj.get("y_max", 0.0) * original_height
boxes.append([x_min, y_min, x_max, y_max])
if not boxes:
return image
detections = sv.Detections(xyxy=np.array(boxes))
if len(detections) == 0:
return image
box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=2)
annotated_image = box_annotator.annotate(scene=np.array(image.copy()), detections=detections)
return Image.fromarray(annotated_image)
return image
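# Single-turn multimodal call: apply the chat template, generate, and decode
# only the newly generated tokens (the prompt tokens are trimmed off).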
def run_qwen_inference(image: Image.Image, prompt: str):
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
],
}
]
inputs = qwen_processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
).to(DEVICE)
with torch.inference_mode():
generated_ids = qwen_model.generate(
**inputs,
max_new_tokens=512,
)
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return qwen_processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[0]
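# Decorated with @GPU so ZeroGPU allocates a GPU for each call. The Point and
# Detect branches assume Qwen reports point_2d/bbox_2d on a 0-1000 grid and
# divide by 1000 to get the normalized coordinates expected by annotate_image.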
@GPU
def process_qwen(image: Image.Image, category: str, prompt: str):
if category == "Query":
return run_qwen_inference(image, prompt), {}
elif category == "Caption":
full_prompt = f"Provide a {prompt} length caption for the image."
return run_qwen_inference(image, full_prompt), {}
elif category == "Point":
full_prompt = (
f"Provide 2d point coordinates for {prompt}. Report in JSON format."
)
output_text = run_qwen_inference(image, full_prompt)
parsed_json = safe_parse_json(output_text)
points_result = {"points": []}
if isinstance(parsed_json, list):
for item in parsed_json:
if "point_2d" in item and len(item["point_2d"]) == 2:
x, y = item["point_2d"]
points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
return json.dumps(points_result, indent=2), points_result
elif category == "Detect":
full_prompt = (
f"Provide bounding box coordinates for {prompt}. Report in JSON format."
)
output_text = run_qwen_inference(image, full_prompt)
parsed_json = safe_parse_json(output_text)
objects_result = {"objects": []}
if isinstance(parsed_json, list):
for item in parsed_json:
if "bbox_2d" in item and len(item["bbox_2d"]) == 4:
xmin, ymin, xmax, ymax = item["bbox_2d"]
objects_result["objects"].append(
{
"x_min": xmin / 1000.0,
"y_min": ymin / 1000.0,
"x_max": xmax / 1000.0,
"y_max": ymax / 1000.0,
}
)
return json.dumps(objects_result, indent=2), objects_result
return "Invalid category", {}
def process_inputs(image, category, prompt):
if image is None:
raise gr.Error("Please upload an image.")
if not prompt:
raise gr.Error("Please provide a prompt.")
image.thumbnail((512, 512))
qwen_text, qwen_data = process_qwen(image, category, prompt)
qwen_annotated_image = annotate_image(image.copy(), qwen_data)
return qwen_annotated_image, qwen_text
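# Swap the prompt placeholder to match the selected task.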
def on_category_change(category: str):
if category == "Query":
return gr.Textbox(placeholder="e.g., Count the total number of boats and describe the environment.")
elif category == "Caption":
return gr.Textbox(placeholder="e.g., short, normal, detailed")
elif category == "Point":
return gr.Textbox(placeholder="e.g., The gun held by the person.")
elif category == "Detect":
return gr.Textbox(placeholder="e.g., The headlight of the car.")
return gr.Textbox(placeholder="e.g., detect the object.")
css = """
#main-title h1 {
font-size: 2.3em !important;
}
#output-title h2 {
font-size: 2.1em !important;
}
"""
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("# **Qwen-3VL: Multimodal Understanding**", elem_id="main-title")
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(type="pil", label="Upload Image")
category_select = gr.Radio(
choices=CATEGORIES,
value="Query",
label="Select Task Category",
interactive=True,
)
prompt_input = gr.Textbox(
placeholder="e.g., Count the total number of boats and describe the environment.",
label="Prompt",
lines=1,
)
submit_btn = gr.Button("Process Image", variant="primary")
with gr.Column(scale=2):
qwen_img_output = gr.Image(label="Output Image")
qwen_text_output = gr.Textbox(
label="Text Output", lines=10, interactive=False, show_copy_button=True
)
gr.Examples(
examples=[
["examples/5.jpg", "Point", "Detect the children who are out of focus and wearing a white T-shirt."],
["examples/5.jpg", "Detect", "Point out the out-of-focus (all) children."],
["examples/4.jpg", "Detect", "Headlight"],
["examples/3.jpg", "Point", "Gun"],
["examples/1.jpg", "Query", "Count the total number of boats and describe the environment."],
["examples/2.jpg", "Caption", "a brief"],
],
inputs=[image_input, category_select, prompt_input],
)
category_select.change(
fn=on_category_change,
inputs=[category_select],
outputs=[prompt_input],
)
submit_btn.click(
fn=process_inputs,
inputs=[image_input, category_select, prompt_input],
outputs=[qwen_img_output, qwen_text_output],
)
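# mcp_server=True also exposes the app's endpoints as MCP tools;
# ssr_mode=False disables Gradio's server-side rendering.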
if __name__ == "__main__":
demo.launch(mcp_server=True, ssr_mode=False, show_error=True) |