import spaces, ffmpeg, os, sys, torch, time
import gradio as gr
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoModelForImageTextToText,
    Gemma3nForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig,
)
from qwen_vl_utils import process_vision_info
from loguru import logger

logger.remove()
logger.add(
    sys.stderr,
    format="<d>{time:YYYY-MM-DD ddd HH:mm:ss}</d> | <lvl>{level}</lvl> | <lvl>{message}</lvl>",
)
# --- Installing Flash Attention on a ZeroGPU Space is special: no GPU is visible at
# startup, so the install has to skip flash-attn's CUDA build check. --- #
import subprocess

subprocess.run(
    "pip install flash-attn --no-build-isolation",
    # Keep the existing environment (PATH, HF caches, ...) and only add the flag
    # that tells flash-attn's setup to skip its CUDA availability check.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
# --- Flash Attention is now installed --- #
# Set target DEVICE and DTYPE.
# For maximum memory efficiency, use bfloat16 if the GPU supports it, otherwise float16.
DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Use "auto" to let accelerate handle device placement (GPU, CPU, disk).
DEVICE = "auto"
logger.info(f"Device: {DEVICE}, dtype: {DTYPE}")

def get_fps_ffmpeg(video_path: str) -> float:
    """Return the frame rate of the first video stream in `video_path` using ffprobe."""
    probe = ffmpeg.probe(video_path)
    # Find the first video stream
    video_stream = next(
        (stream for stream in probe["streams"] if stream["codec_type"] == "video"), None
    )
    if video_stream is None:
        raise ValueError("No video stream found")
    # Frame rate is reported as a string fraction, e.g. '30000/1001'
    r_frame_rate = video_stream["r_frame_rate"]
    num, denom = map(int, r_frame_rate.split("/"))
    return num / denom
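
# A quick illustration of the helper above (hypothetical file name): a stream whose
# r_frame_rate is "30000/1001" yields get_fps_ffmpeg("clip.mp4") == 30000 / 1001 ≈ 29.97.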

def load_model(
    model_name: str = "chancharikm/qwen2.5-vl-7b-cam-motion-preview",
    use_flash_attention: bool = True,
    apply_quantization: bool = True,
):
    """Load a VLM by name, optionally with flash_attention_2 and 4-bit quantization."""
    # flash_attention_2 is recommended for better acceleration and memory savings,
    # especially in multi-image and video scenarios.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,  # Load model weights in 4-bit
        bnb_4bit_quant_type="nf4",  # Use NF4 quantization (or "fp4")
        bnb_4bit_compute_dtype=DTYPE,  # Perform computations in bfloat16/float16
        bnb_4bit_use_double_quant=True,  # Optional: nested quantization for slightly more memory saving
    )
    # Determine the model family from the model name,
    # e.g. "OpenGVLab/InternVL3-1B-hf" -> "InternVL3"
    model_family = model_name.split("/")[-1].split("-")[0]
    # Common model loading arguments
    common_args = {
        "torch_dtype": DTYPE,
        "device_map": DEVICE,
        "low_cpu_mem_usage": True,
        "quantization_config": bnb_config if apply_quantization else None,
    }
    # Add flash attention if requested
    if use_flash_attention:
        common_args["attn_implementation"] = "flash_attention_2"
    # Load the model class that matches the family
    match model_family:
        case "qwen2.5" | "Qwen2.5":
            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_name, **common_args
            )
        case "InternVL3":
            model = AutoModelForImageTextToText.from_pretrained(
                model_name, **common_args
            )
        case "gemma":
            model = Gemma3nForConditionalGeneration.from_pretrained(
                model_name, **common_args
            )
        case _:
            raise ValueError(f"Unsupported model family: {model_family}")
    # Set the model to evaluation mode for inference (disables dropout, etc.)
    return model.eval()

def load_processor(model_name="Qwen/Qwen2.5-VL-7B-Instruct"):
    return AutoProcessor.from_pretrained(
        model_name,
        device_map=DEVICE,
        use_fast=True,
        torch_dtype=DTYPE,
    )
logger.debug("Loading Models and Processors...") | |
MODEL_ZOO = { | |
"qwen2.5-vl-7b-cam-motion-preview": load_model( | |
model_name="chancharikm/qwen2.5-vl-7b-cam-motion-preview", | |
use_flash_attention=False, | |
apply_quantization=False, | |
), | |
"qwen2.5-vl-7b-instruct": load_model( | |
model_name="Qwen/Qwen2.5-VL-7B-Instruct", | |
use_flash_attention=False, | |
apply_quantization=False, | |
), | |
"qwen2.5-vl-3b-instruct": load_model( | |
model_name="Qwen/Qwen2.5-VL-3B-Instruct", | |
use_flash_attention=False, | |
apply_quantization=False, | |
), | |
"InternVL3-1B-hf": load_model( | |
model_name="OpenGVLab/InternVL3-1B-hf", | |
use_flash_attention=False, | |
apply_quantization=False, | |
), | |
"InternVL3-2B-hf": load_model( | |
model_name="OpenGVLab/InternVL3-2B-hf", | |
use_flash_attention=False, | |
apply_quantization=False, | |
), | |
"InternVL3-8B-hf": load_model( | |
model_name="OpenGVLab/InternVL3-8B-hf", | |
use_flash_attention=False, | |
apply_quantization=True, | |
), | |
"gemma-3n-e4b-it": load_model( | |
model_name="google/gemma-3n-e4b-it", | |
use_flash_attention=False, | |
apply_quantization=True, | |
), | |
} | |
PROCESSORS = { | |
"qwen2.5-vl-7b-cam-motion-preview": load_processor("Qwen/Qwen2.5-VL-7B-Instruct"), | |
"qwen2.5-vl-7b-instruct": load_processor("Qwen/Qwen2.5-VL-7B-Instruct"), | |
"qwen2.5-vl-3b-instruct": load_processor("Qwen/Qwen2.5-VL-3B-Instruct"), | |
"InternVL3-1B-hf": load_processor("OpenGVLab/InternVL3-1B-hf"), | |
"InternVL3-2B-hf": load_processor("OpenGVLab/InternVL3-2B-hf"), | |
"InternVL3-8B-hf": load_processor("OpenGVLab/InternVL3-8B-hf"), | |
"gemma-3n-e4b-it": load_processor("google/gemma-3n-e4b-it"), | |
} | |
logger.debug("Models and Processors Loaded!") | |

@spaces.GPU  # ZeroGPU allocates a GPU only for calls into functions decorated with spaces.GPU
def inference(
    video_path: str,
    prompt: str = "Describe the camera motion in this video.",
    model_name: str = "qwen2.5-vl-7b-instruct",
    custom_fps: int = 8,
    max_tokens: int = 256,
    temperature: float = 0.0,
):
    """Caption a video with the selected model and return the text, fps, and timing."""
    s_time = time.time()
    # default processor
    # processor, model = PROCESSOR, MODEL
    # processor = load_processor()
    # model = load_model(
    #     use_flash_attention=use_flash_attention, apply_quantization=apply_quantization
    # )
    model = MODEL_ZOO[model_name]
    processor = PROCESSORS[model_name]
    # The model is trained on 8.0 FPS, which we recommend for optimal inference.
    # A falsy custom_fps (0) falls back to the input video's own frame rate.
    fps = custom_fps if custom_fps else get_fps_ffmpeg(video_path)
    logger.info(f"{os.path.basename(video_path)} FPS: {fps}")
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": video_path,
                    "fps": fps,
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]
    # torch.no_grad() prevents PyTorch from building the computation graph for gradients,
    # saving a significant amount of memory for intermediate activations.
    with torch.no_grad():
        model_family = model_name.split("-")[0]
        match model_family:
            case "qwen2.5":
                text = processor.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                image_inputs, video_inputs, video_kwargs = process_vision_info(
                    messages, return_video_kwargs=True
                )
                inputs = processor(
                    text=[text],
                    images=image_inputs,
                    videos=video_inputs,
                    # fps=fps,
                    padding=True,
                    return_tensors="pt",
                    **video_kwargs,
                )
                inputs = inputs.to("cuda")
                # Inference
                generated_ids = model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=float(temperature),
                    do_sample=temperature > 0.0,
                )
                # Strip the prompt tokens so only newly generated text is decoded
                generated_ids_trimmed = [
                    out_ids[len(in_ids) :]
                    for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                ]
                output_text = processor.batch_decode(
                    generated_ids_trimmed,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )[0]
            case "InternVL3" | "gemma":
                inputs = processor.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=True,
                    return_dict=True,
                    return_tensors="pt",
                    fps=fps,
                    # num_frames=8,
                ).to("cuda", dtype=DTYPE)
                output = model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=float(temperature),
                    do_sample=temperature > 0.0,
                )
                output_text = processor.decode(
                    output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
                )
            case _:
                raise ValueError(f"{model_name} is not currently supported")
    return {
        "output_text": output_text,
        "fps": fps,
        "inference_time": time.time() - s_time,
    }
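
# Hedged usage sketch: how `inference` can be exercised directly, outside the Gradio UI,
# as an opt-in smoke test. The RUN_SMOKE_TEST env var and "sample.mp4" path are
# illustrative assumptions, not part of this Space.
if os.environ.get("RUN_SMOKE_TEST"):
    smoke_result = inference(
        video_path="sample.mp4",  # hypothetical local file
        prompt="Describe the camera motion in this video.",
        model_name="qwen2.5-vl-3b-instruct",
        custom_fps=8,
        max_tokens=128,
        temperature=0.0,
    )
    logger.info(smoke_result["output_text"])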

demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Video(label="Input Video"),
        gr.Textbox(
            label="Prompt",
            lines=3,
            info="Some models like [cam motion](https://huggingface.co/chancharikm/qwen2.5-vl-7b-cam-motion-preview) are trained on specific prompts",
            value="Describe the camera motion in this video.",
        ),
        gr.Dropdown(
            label="Model",
            choices=list(MODEL_ZOO.keys()),
            value="qwen2.5-vl-7b-instruct",  # matches the default in `inference`
        ),
        gr.Number(
            label="FPS",
            info="Inference sampling rate (Qwen2.5-VL is trained on videos at 8 FPS); a value of 0 means the FPS of the input video is used",
            value=8,
            minimum=0,
            step=1,
        ),
        gr.Slider(
            label="Max Tokens",
            info="Maximum number of new tokens to generate",
            value=128,
            minimum=32,
            maximum=512,
            step=32,
        ),
        gr.Slider(
            label="Temperature",
            value=0.0,
            minimum=0.0,
            maximum=1.0,
            step=0.1,
        ),
        # gr.Checkbox(label="Use Flash Attention", value=False),
        # gr.Checkbox(label="Apply Quantization", value=True),
    ],
    outputs=gr.JSON(label="Output JSON"),
    title="Video Captioning with VLM",
    description='Comparing various "small" VLMs on the task of video captioning',
    api_name="video_inference",
)

demo.launch(
    mcp_server=True,  # also expose the app as an MCP server
    app_kwargs={"docs_url": "/docs"},  # add FastAPI Swagger API docs
)
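
# Hedged client sketch: because the Interface is exposed under api_name "video_inference",
# it can be called programmatically with gradio_client. The Space id and video path below
# are placeholders, and the exact payload shape for gr.Video may vary across Gradio versions.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("<user>/<space-name>")  # placeholder Space id
#   result = client.predict(
#       handle_file("local_video.mp4"),  # hypothetical local file
#       "Describe the camera motion in this video.",
#       "qwen2.5-vl-3b-instruct",
#       8,      # FPS
#       128,    # Max Tokens
#       0.0,    # Temperature
#       api_name="/video_inference",
#   )
#   print(result)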