"""
VideoAuto-R1 (Qwen3-VL) Demo
A Gradio-based chat interface for adaptive inference with image/video inputs.
"""
import spaces
import os
import base64
from io import BytesIO
import torch
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer
from videoauto_r1.qwen_vl_utils.vision_process import process_vision_info
from videoauto_r1.modeling_qwen3_vl_patched import Qwen3VLForConditionalGeneration
from videoauto_r1.early_exit import compute_first_boxed_answer_probs
# ============================================================================
# Constants
# ============================================================================
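# "Answer twice" protocol: the model first commits to a quick \boxed{} answer,
# then (optionally) reasons inside <think>...</think> and answers again.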
COT_SYSTEM_PROMPT_ANSWER_TWICE = (
"You are a helpful assistant.\n"
"FIRST: Output your initial answer inside the first \\boxed{...} without any analysis or explanations. "
"If you cannot determine the answer without reasoning, output \\boxed{Let's analyze the problem step by step.} instead.\n"
"THEN: Think through the reasoning as an internal monologue enclosed within <think>...</think>.\n"
"AT LAST: Output the final answer again inside \\boxed{...}. If you believe the previous answer was correct, repeat it; otherwise, correct it.\n"
"Output format: \\boxed{...}<think>...</think>\\boxed{...}"
)
VIDEO_EXTS = (".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm")
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tiff")
CUSTOM_CSS = """
#chatbot .message[class*="user"] {
max-width: 50% !important;
}
#chatbot .message[class*="bot"],
#chatbot .message[class*="assistant"] {
max-width: 60% !important;
}
#chatbot .message > div {
width: 100% !important;
max-width: 100% !important;
}
"""
MODEL_PATH = "IVUL-KAUST/VideoAuto-R1-Qwen3-VL-8B"
# ============================================================================
# Global Model Variables
# ============================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model
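# bfloat16 weights with SDPA attention keep memory modest; the patched
# Qwen3-VL class and the early-exit helper ship with the videoauto_r1
# package (see imports above).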
model = (
    Qwen3VLForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        dtype="bfloat16",
        attn_implementation="sdpa",
    )
    .to(device)  # use the detected device instead of hard-coding "cuda"
    .eval()
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# ============================================================================
# Utility Functions
# ============================================================================
def detect_media_type(file_path: str | None) -> str | None:
"""
Detect media type from file extension.
Args:
file_path: Path to the media file
Returns:
'image', 'video', or None
"""
if not file_path:
return None
p = file_path.lower()
if p.endswith(VIDEO_EXTS):
return "video"
if p.endswith(IMAGE_EXTS):
return "image"
# Fallback: try to open as image
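    # (PIL sniffs the format from file contents, so extension-less images are
    # still caught; anything it cannot open is treated as video.)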
try:
Image.open(file_path)
return "image"
except Exception:
return "video"
def process_image(
image_path: str,
image_min_pixels: int = 128 * 28 * 28,
image_max_pixels: int = 16384 * 28 * 28,
) -> dict | None:
"""
Process image file to base64 format.
Args:
image_path: Path to image file
image_min_pixels: Minimum pixel count
image_max_pixels: Maximum pixel count
Returns:
Dictionary with image data or None
"""
if image_path is None:
return None
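    # Re-encode as RGB JPEG and inline it as a base64 data URI so the vision
    # pipeline receives a self-contained reference instead of a local path.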
image = Image.open(image_path).convert("RGB")
buffer = BytesIO()
image.save(buffer, format="JPEG")
base64_bytes = base64.b64encode(buffer.getvalue())
base64_string = base64_bytes.decode("utf-8")
return {
"type": "image",
"image": f"data:image/jpeg;base64,{base64_string}",
"min_pixels": image_min_pixels,
"max_pixels": image_max_pixels,
}
def process_video(
video_path: str,
video_min_pixels: int = 16 * 28 * 28,
video_max_pixels: int = 768 * 28 * 28,
video_total_pixels: int = 128000 * 28 * 28,
min_frames: int = 4,
max_frames: int = 64,
fps: float = 2.0,
) -> dict | None:
"""
Process video file configuration.
Args:
video_path: Path to video file
video_min_pixels: Minimum pixels per frame
video_max_pixels: Maximum pixels per frame
video_total_pixels: Total pixels across all frames
min_frames: Minimum number of frames
max_frames: Maximum number of frames
fps: Frames per second for sampling
Returns:
Dictionary with video configuration or None
"""
if video_path is None:
return None
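    # Actual frame sampling happens later in process_vision_info, which (per
    # the qwen-vl-utils convention) targets `fps` frames per second, clamped
    # to [min_frames, max_frames], and resizes frames to respect the
    # per-frame and total pixel budgets.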
return {
"type": "video",
"video": video_path,
"min_pixels": video_min_pixels,
"max_pixels": video_max_pixels,
"total_pixels": video_total_pixels,
"min_frames": min_frames,
"max_frames": max_frames,
"fps": fps,
}
@spaces.GPU(duration=180)
def generate(
media_input: str | None,
prompt: str,
early_exit_thresh: float,
temperature: float,
max_new_tokens: int = 4096,
) -> dict:
"""
Generate response with adaptive inference.
Args:
media_input: Path to media file
prompt: Text prompt
early_exit_thresh: Confidence threshold for early exit
temperature: Sampling temperature
max_new_tokens: Maximum tokens to generate
Returns:
Dictionary containing response and metadata
"""
# Prepare message
message = [{"role": "system", "content": COT_SYSTEM_PROMPT_ANSWER_TWICE}]
content_parts = []
# Process media input
if media_input is not None:
media_type = detect_media_type(media_input)
if media_type == "video":
video_dict = process_video(media_input)
if video_dict:
content_parts.append(video_dict)
elif media_type == "image":
image_dict = process_image(media_input)
if image_dict:
content_parts.append(image_dict)
# Add text prompt
content_parts.append({"type": "text", "text": prompt})
message.append({"role": "user", "content": content_parts})
# Apply chat template
text = processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True)
# Process vision inputs
image_inputs, video_inputs, video_kwargs = process_vision_info(
[message],
image_patch_size=16,
return_video_kwargs=True,
return_video_metadata=True,
)
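    # With return_video_metadata=True, each video entry is a (frames, metadata)
    # tuple; split them so both can be passed to the processor.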
if video_inputs is not None:
video_inputs, video_metadatas = zip(*video_inputs)
video_inputs = list(video_inputs)
video_metadatas = list(video_metadatas)
else:
video_metadatas = None
# Prepare inputs
inputs = processor(
text=text,
images=image_inputs,
videos=video_inputs,
video_metadata=video_metadatas,
do_resize=False,
padding=True,
return_tensors="pt",
**video_kwargs,
)
inputs = inputs.to(device)
# Generation configuration
gen_kwargs = {
"max_new_tokens": max_new_tokens,
"temperature": temperature if temperature > 0 else None,
"do_sample": temperature > 0,
"top_p": 0.9 if temperature > 0 else None,
"num_beams": 1,
"use_cache": True,
"return_dict_in_generate": True,
"output_scores": True,
}
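    # return_dict_in_generate + output_scores keep the per-step logits that the
    # confidence computation below relies on.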
# Generate response
with torch.no_grad():
gen_out = model.generate(
**inputs,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
**gen_kwargs,
)
# Decode output
generated_ids = gen_out.sequences[0][len(inputs.input_ids[0]) :]
answer = processor.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
# Compute confidence
first_box_probs = compute_first_boxed_answer_probs(
b=0,
gen_ids=generated_ids,
gen_out=gen_out,
ans=answer,
task="",
tokenizer=tokenizer,
)
# Parse response
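    # Expected layout per the system prompt: \boxed{first}<think>...</think>\boxed{final}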
first_answer = answer.split("<think>")[0]
second_answer = answer.split("</think>")[-1] if "</think>" in answer else first_answer
reasoning = answer.split("<think>")[-1].split("</think>")[0] if "<think>" in answer else "N/A"
    # Determine inference mode: the full response was already generated above,
    # so the threshold only decides whether the CoT is surfaced to the user.
    need_cot = first_box_probs < early_exit_thresh
return {
"full_response": answer,
"first_answer": first_answer,
"confidence": f"{first_box_probs:.4f}",
"need_cot": need_cot,
"reasoning": reasoning,
"second_answer": second_answer,
}
# ============================================================================
# Gradio Callback Functions
# ============================================================================
def update_preview(file_path: str | None):
"""Update preview widgets based on media type."""
mtype = detect_media_type(file_path)
if mtype == "image":
return (
gr.update(value=file_path, visible=True), # image_preview
gr.update(value=None, visible=False), # video_preview
)
elif mtype == "video":
return (
gr.update(value=None, visible=False), # image_preview
gr.update(value=file_path, visible=True), # video_preview
)
else:
return (
gr.update(value=None, visible=False),
gr.update(value=None, visible=False),
)
def chat_generate(
media_path,
user_text,
messages_state,
chatbot_state,
last_media_state,
early_exit_thresh,
temperature,
):
"""Handle chat message generation."""
if user_text is None or str(user_text).strip() == "":
raise gr.Error("Chat message cannot be empty.")
# Clear history if media changed
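    # (Gradio typically stages uploads in fresh temp directories, so full
    # paths can differ even for the same file; basenames are a more stable
    # identity check.)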
if (
(media_path is not None)
and (last_media_state is not None)
and (os.path.basename(media_path) != os.path.basename(last_media_state))
):
messages_state = []
chatbot_state = []
# Initialize system prompt
if len(messages_state) == 0:
messages_state.append({"role": "system", "content": COT_SYSTEM_PROMPT_ANSWER_TWICE})
# Prepare user message
content_parts = []
if media_path is not None:
mtype = detect_media_type(media_path)
if mtype == "video":
vd = process_video(media_path)
if vd:
content_parts.append(vd)
elif mtype == "image":
imd = process_image(media_path)
if imd:
content_parts.append(imd)
content_parts.append({"type": "text", "text": user_text})
messages_state.append({"role": "user", "content": content_parts})
# Generate response
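    # Note: generate() builds a fresh single-turn request from the current
    # media + prompt; messages_state only tracks the displayed transcript.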
result = generate(media_path, user_text, early_exit_thresh, temperature)
# Format assistant response
first_ans = (result.get("first_answer") or "").strip()
conf = result.get("confidence", "N/A")
    need_cot = result.get("need_cot", False)
reasoning = result.get("reasoning", "")
final_ans = (result.get("second_answer") or "").strip()
if need_cot:
decision_prompt = f"Continue CoT Reasoning (confidence = {conf})"
else:
decision_prompt = f"Early Exit (confidence = {conf})"
    assistant_display_1 = f"**Initial Answer:**\n{first_ans}\n\n**{decision_prompt}**\n\n"
# Update state
messages_state.append({"role": "assistant", "content": assistant_display_1})
chatbot_state.append({"role": "user", "content": user_text})
chatbot_state.append({"role": "assistant", "content": assistant_display_1})
if need_cot:
        assistant_display_2 = (
            f"\n\n**<think>**\n\n{reasoning}\n**</think>**\n\n"
            f"**Reviewed Answer:**\n{final_ans}\n\n"
        )
messages_state.append({"role": "assistant", "content": assistant_display_2})
chatbot_state.append({"role": "assistant", "content": assistant_display_2})
# Disable textbox and send button after generation to prevent interleaved conversation
return (
messages_state,
chatbot_state,
media_path,
gr.update(value="", interactive=False), # Disable and clear textbox
gr.update(interactive=False), # Disable send button
)
def clear_history():
"""Clear all chat history and reset interface."""
return (
[], # messages_state
[], # chatbot_state
None, # last_media_state
gr.update(value=None), # file
gr.update(value=None, visible=False), # image_preview
gr.update(value=None, visible=False), # video_preview
gr.update(value="", interactive=True), # Re-enable and clear textbox
gr.update(interactive=True), # Re-enable send button
)
# ============================================================================
# Example Data
# ============================================================================
EXAMPLES = [
[
"assets/yt--MAYaJ5cyOE_70.mp4",
"Question: Which one of these descriptions correctly matches the actions in the video?\nOptions:\n(A) officiating\n(B) skating\n(C) stopping\n(D) playing sports\nPut your final answer in \\boxed{}.",
# GT is B
],
[
"assets/validation_Finance_2.mp4",
"Using the Arbitrage Pricing Theory model shown above, calculate the expected return E(rp) if the risk-free rate increases to 5%. All other risk premiums (RP) and beta (\\beta) values remain unchanged.\nOptions:\nA. 13.4%\nB. 14.8%\nC. 15.6%\nD. 16.1%\nE. 16.5%\nF. 16.9%\nG. 17.5%\nH. 17.8%\nI. 17.2%\nJ. 18.1%\nPut your final answer in \\boxed{}.",
# GT is I
],
[
"assets/M3CoT-25169-0.png",
"Within the image, you'll notice several purchased items. And we assume that the water temperature is 4 ° C at this time.\nWithin the image, can you identify the count of items among the provided options that will go below the waterline?\nA. 0\nB. 1\nC. 2\nD. 3\nPut your final answer in \\boxed{}.",
# GT is B
],
[
None,
"Determine the value of the parameter $m$ such that the equation $(m-2)x^2 + (m^2-4m+3)x - (6m^2-2) = 0$ has real solutions, and the sum of the cubes of these solutions is equal to zero.\nPut your final answer in \\boxed{}.",
# GT is 3
],
]
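# The last example is text-only (media=None) to exercise the no-media path.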
# ============================================================================
# Gradio Interface
# ============================================================================
demo = gr.Blocks(title="VideoAuto-R1 Demo", css=CUSTOM_CSS)
with demo:
gr.Markdown("# [VideoAuto-R1 Demo](https://github.com/IVUL-KAUST/VideoAuto-R1/)")
# Display system prompt
with gr.Accordion("System Prompt", open=False):
gr.Markdown(f"```\n{COT_SYSTEM_PROMPT_ANSWER_TWICE}\n```")
# State variables
messages_state = gr.State([])
chatbot_state = gr.State([])
last_media_state = gr.State(None)
with gr.Row():
# Left column: Media input and settings
with gr.Column(scale=3):
media_input = gr.File(
label="Upload Image or Video",
file_types=["image", "video"],
type="filepath",
)
image_preview = gr.Image(label="Image Preview", visible=False)
video_preview = gr.Video(label="Video Preview", visible=False)
with gr.Accordion("Advanced Settings", open=True):
early_exit_thresh = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.98,
step=0.01,
label="Early Exit Threshold",
)
temperature = gr.Slider(
minimum=0.0,
maximum=2.0,
value=0.0,
step=0.1,
label="Temperature",
)
# Right column: Chat interface
with gr.Column(scale=7):
chatbot = gr.Chatbot(
label="Chat",
elem_id="chatbot",
height=600,
sanitize_html=False,
)
textbox = gr.Textbox(
show_label=False,
placeholder="Enter text and press ENTER",
lines=2,
)
with gr.Row():
send_btn = gr.Button("Send", variant="primary")
clear_btn = gr.Button("Clear")
gr.Markdown("Please click the **Clear** button before starting a new conversation or trying a new example.")
# Event handlers
media_input.change(
fn=update_preview,
inputs=[media_input],
outputs=[image_preview, video_preview],
)
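    # Each generation/clear event below chains a `.then` that mirrors
    # chatbot_state into the Chatbot widget, keeping display and state in sync.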
# Send button click: generate response and disable input controls
send_btn.click(
fn=chat_generate,
inputs=[
media_input,
textbox,
messages_state,
chatbot_state,
last_media_state,
early_exit_thresh,
temperature,
],
outputs=[messages_state, chatbot_state, last_media_state, textbox, send_btn],
).then(
fn=lambda cs: cs,
inputs=[chatbot_state],
outputs=[chatbot],
)
# Textbox submit: generate response and disable input controls
textbox.submit(
fn=chat_generate,
inputs=[
media_input,
textbox,
messages_state,
chatbot_state,
last_media_state,
early_exit_thresh,
temperature,
],
outputs=[messages_state, chatbot_state, last_media_state, textbox, send_btn],
).then(
fn=lambda cs: cs,
inputs=[chatbot_state],
outputs=[chatbot],
)
# Clear button: reset all states and re-enable input controls
clear_btn.click(
fn=clear_history,
inputs=[],
outputs=[
messages_state,
chatbot_state,
last_media_state,
media_input,
image_preview,
video_preview,
textbox,
send_btn,
],
).then(
fn=lambda cs: cs,
inputs=[chatbot_state],
outputs=[chatbot],
)
examples_ds = gr.Dataset(
components=[media_input, textbox],
samples=EXAMPLES,
label="Examples",
type="index", # important: pass selected row index to fn
)
def load_example(idx: int | None):
# idx can be None when deselecting
if idx is None:
# just clear everything
return clear_history()
media, text = EXAMPLES[idx][0], EXAMPLES[idx][1]
# 1) clear all states + re-enable inputs
ms, cs, last, file_u, img_u, vid_u, tb_u, send_u = clear_history()
# 2) set selected example values
file_u = gr.update(value=media)
tb_u = gr.update(value=text, interactive=True)
send_u = gr.update(interactive=True)
# 3) update preview explicitly (don't rely on File.change always firing)
img_u, vid_u = update_preview(media)
# 4) optionally set last_media_state to current media
last = media
return ms, cs, last, file_u, img_u, vid_u, tb_u, send_u
examples_ds.select(
fn=load_example,
inputs=[examples_ds],
outputs=[
messages_state,
chatbot_state,
last_media_state,
media_input,
image_preview,
video_preview,
textbox,
send_btn,
],
).then(
fn=lambda cs: cs,
inputs=[chatbot_state],
outputs=[chatbot],
)
# Launch demo
# CSS is applied via the gr.Blocks constructor above; launch() has no css parameter.
demo.launch(
    share=True,
    server_name="0.0.0.0",
    server_port=7860,
    allowed_paths=["assets"],
    debug=True,
)