import os
import pathlib
import tempfile
from collections.abc import Iterator
from threading import Thread

import av
import gradio as gr
import torch
from gradio.utils import get_upload_folder
from transformers import AutoProcessor
from transformers.generation.streamers import TextIteratorStreamer
from optimum.intel import OVModelForVisualCausalLM

default_model_id = "echarlaix/SmolVLM2-500M-Video-Instruct-openvino"
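
# Cache the currently loaded processor and model so that switching models from the
# UI only triggers a reload when the requested model id actually changes.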
model_cache = {
    "model_id": default_model_id,
    "processor": AutoProcessor.from_pretrained(default_model_id),
    "model": OVModelForVisualCausalLM.from_pretrained(default_model_id),
}


def update_model(model_id):
    if model_cache["model_id"] != model_id:
        model_cache["model_id"] = model_id
        model_cache["processor"] = AutoProcessor.from_pretrained(model_id)
        model_cache["model"] = OVModelForVisualCausalLM.from_pretrained(model_id)
IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
VIDEO_FILE_TYPES = (".mp4", ".mov", ".webm")
GRADIO_TEMP_DIR = get_upload_folder()
TARGET_FPS = int(os.getenv("TARGET_FPS", "3"))
MAX_FRAMES = int(os.getenv("MAX_FRAMES", "30"))
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))


def get_file_type(path: str) -> str:
    if path.endswith(IMAGE_FILE_TYPES):
        return "image"
    if path.endswith(VIDEO_FILE_TYPES):
        return "video"
    error_message = f"Unsupported file type: {path}"
    raise ValueError(error_message)


def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    video_count = 0
    non_video_count = 0
    for path in paths:
        if path.endswith(VIDEO_FILE_TYPES):
            video_count += 1
        else:
            non_video_count += 1
    return video_count, non_video_count


def validate_media_constraints(message: dict) -> bool:
    video_count, non_video_count = count_files_in_new_message(message["files"])
    if video_count > 1:
        gr.Warning("Only one video is supported.")
        return False
    if video_count == 1 and non_video_count > 0:
        gr.Warning("Mixing images and videos is not allowed.")
        return False
    return True
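

# Decode the uploaded video with PyAV and save frames sampled at roughly `target_fps`
# frames per second as JPEGs in a fresh temporary directory, stopping after `max_frames`.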
def extract_frames_to_tempdir(
    video_path: str,
    target_fps: float,
    max_frames: int | None = None,
    parent_dir: str | None = None,
    prefix: str = "frames_",
) -> str:
    temp_dir = tempfile.mkdtemp(prefix=prefix, dir=parent_dir)

    container = av.open(video_path)
    video_stream = container.streams.video[0]

    if video_stream.duration is None or video_stream.time_base is None:
        raise ValueError("video_stream is missing duration or time_base")

    time_base = video_stream.time_base
    duration = float(video_stream.duration * time_base)
    interval = 1.0 / target_fps

    total_frames = int(duration * target_fps)
    if max_frames is not None:
        total_frames = min(total_frames, max_frames)

    target_times = [i * interval for i in range(total_frames)]
    target_index = 0

    for frame in container.decode(video=0):
        if frame.pts is None:
            continue
        timestamp = float(frame.pts * time_base)
        # Save the first decoded frame that lands within half a sampling interval of each target time.
        if target_index < len(target_times) and abs(timestamp - target_times[target_index]) < (interval / 2):
            frame_path = pathlib.Path(temp_dir) / f"frame_{target_index:04d}.jpg"
            frame.to_image().save(frame_path)
            target_index += 1

        if max_frames is not None and target_index >= max_frames:
            break

    container.close()
    return temp_dir
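

# Turn a new Gradio multimodal message into chat-template content entries. A single
# uploaded video is expanded into its sampled frames and passed as a sequence of images.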
def process_new_user_message(message: dict) -> list[dict]:
    if not message["files"]:
        return [{"type": "text", "text": message["text"]}]

    file_types = [get_file_type(path) for path in message["files"]]

    if len(file_types) == 1 and file_types[0] == "video":
        gr.Info(f"Video will be processed at {TARGET_FPS} FPS, max {MAX_FRAMES} frames in this Space.")
        temp_dir = extract_frames_to_tempdir(
            message["files"][0],
            target_fps=TARGET_FPS,
            max_frames=MAX_FRAMES,
            parent_dir=GRADIO_TEMP_DIR,
        )
        paths = sorted(pathlib.Path(temp_dir).glob("*.jpg"))
        return [
            {"type": "text", "text": message["text"]},
            *[{"type": "image", "image": path.as_posix()} for path in paths],
        ]

    return [
        {"type": "text", "text": message["text"]},
        *[{"type": file_type, file_type: path} for path, file_type in zip(message["files"], file_types, strict=True)],
    ]
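

# Rebuild the chat-template message list from the Gradio history, grouping the text and
# file entries of consecutive user turns into a single user message.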
def process_history(history: list[dict]) -> list[dict]:
    messages = []
    current_user_content: list[dict] = []
    for item in history:
        if item["role"] == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
        else:
            content = item["content"]
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            else:
                filepath = content[0]
                file_type = get_file_type(filepath)
                current_user_content.append({"type": file_type, file_type: filepath})
    return messages
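

# Chat callback: load the selected OpenVINO model if needed, build the prompt from the
# history and the new message, and stream the generated text back to the UI.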
@torch.inference_mode()
def generate(message: dict, history: list[dict], model_id: str, max_new_tokens: int = 512) -> Iterator[str]:
    system_prompt = "You are a helpful assistant."
    update_model(model_id)
    processor = model_cache["processor"]
    model = model_cache["model"]

    if not validate_media_constraints(message):
        yield ""
        return

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    n_tokens = inputs["input_ids"].shape[1]
    if n_tokens > MAX_INPUT_TOKENS:
        gr.Warning(
            f"Input too long. Max {MAX_INPUT_TOKENS} tokens. Got {n_tokens} tokens. "
            "This limit is set to avoid out-of-memory errors in this Space."
        )
        yield ""
        return

    # inputs = inputs.to(device=model.device, dtype=torch.bfloat16)

    # Run generation in a background thread and stream partial output as it is produced.
    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        disable_compile=True,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    output = ""
    for delta in streamer:
        output += delta
        yield output
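

# Example prompts shown below the chat box.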
examples = [
    [
        {
            "text": "What is on the flower?",
            "files": ["assets/bee.jpg"],
        }
    ],
    [
        {
            "text": "Describe this image in detail.",
            "files": ["assets/dogs.jpg"],
        }
    ],
    [
        {
            "text": "Give me a short and easy recipe for this dish",
            "files": ["assets/recipe_burger.webp"],
        }
    ],
    [
        {
            "text": "I want to go somewhere similar to the one in the photo. Give me destinations and travel tips",
            "files": ["assets/travel_tips.jpg"],
        }
    ],
    [
        {
            "text": "As an art critic AI assistant, could you describe this painting in detail and give a thorough critique?",
            "files": ["assets/art_critic.png"],
        }
    ],
    [
        {
            "text": "What is the capital of France?",
            "files": [],
        }
    ],
]


model_choices = [
    "echarlaix/SmolVLM2-500M-Video-Instruct-openvino",
    "echarlaix/SmolVLM2-500M-Video-Instruct-openvino-8bit-static",
    "echarlaix/SmolVLM2-500M-Video-Instruct-openvino-8bit-woq",
]


demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    textbox=gr.MultimodalTextbox(
        file_types=list(IMAGE_FILE_TYPES + VIDEO_FILE_TYPES),
        file_count="multiple",
        autofocus=True,
    ),
    multimodal=True,
    additional_inputs=[
        gr.Dropdown(model_choices, value=model_choices[0], label="Model ID"),
        # gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
    ],
    stop_btn=False,
    title="Fast quantized SmolVLM2 ⚡",
    description=(
        "Play with [SmolVLM2-500M-Video-Instruct](https://huggingface.co/echarlaix/SmolVLM2-500M-Video-Instruct-openvino) "
        "and its quantized variants: [SmolVLM2-500M-Video-Instruct-openvino-8bit-woq](https://huggingface.co/echarlaix/SmolVLM2-500M-Video-Instruct-openvino-8bit-woq) "
        "and [SmolVLM2-500M-Video-Instruct-openvino-8bit-static](https://huggingface.co/echarlaix/SmolVLM2-500M-Video-Instruct-openvino-8bit-static), "
        "obtained by applying Weight-Only Quantization and Static Quantization respectively, using the "
        "[Optimum Intel](https://github.com/huggingface/optimum-intel) NNCF integration. "
        "To get started, upload an image and some text, or try one of the examples. "
        "This demo runs on 4th Generation Intel Xeon (Sapphire Rapids) processors."
    ),
    examples=examples,
    run_examples_on_click=False,
    cache_examples=False,
    css_paths="style.css",
    delete_cache=(1800, 1800),
)


if __name__ == "__main__":
    demo.launch()