import os
import pathlib
import tempfile
from collections.abc import Iterator
from threading import Thread

import av
import gradio as gr
import torch
from gradio.utils import get_upload_folder
from transformers import AutoProcessor
from transformers.generation.streamers import TextIteratorStreamer

from optimum.intel import OVModelForVisualCausalLM

default_model_id = "echarlaix/SmolVLM2-500M-Video-Instruct-openvino"

# Keep a single processor/model pair in memory and reload it only when the
# selected model changes.
model_cache = {
    "model_id": default_model_id,
    "processor": AutoProcessor.from_pretrained(default_model_id),
    "model": OVModelForVisualCausalLM.from_pretrained(default_model_id),
}


def update_model(model_id):
    if model_cache["model_id"] != model_id:
        model_cache["model_id"] = model_id
        model_cache["processor"] = AutoProcessor.from_pretrained(model_id)
        model_cache["model"] = OVModelForVisualCausalLM.from_pretrained(model_id)


IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
VIDEO_FILE_TYPES = (".mp4", ".mov", ".webm")

GRADIO_TEMP_DIR = get_upload_folder()

TARGET_FPS = int(os.getenv("TARGET_FPS", "3"))
MAX_FRAMES = int(os.getenv("MAX_FRAMES", "30"))
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))


def get_file_type(path: str) -> str:
    if path.endswith(IMAGE_FILE_TYPES):
        return "image"
    if path.endswith(VIDEO_FILE_TYPES):
        return "video"
    error_message = f"Unsupported file type: {path}"
    raise ValueError(error_message)


def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    video_count = 0
    non_video_count = 0
    for path in paths:
        if path.endswith(VIDEO_FILE_TYPES):
            video_count += 1
        else:
            non_video_count += 1
    return video_count, non_video_count


def validate_media_constraints(message: dict) -> bool:
    video_count, non_video_count = count_files_in_new_message(message["files"])
    if video_count > 1:
        gr.Warning("Only one video is supported.")
        return False
    if video_count == 1 and non_video_count > 0:
        gr.Warning("Mixing images and videos is not allowed.")
        return False
    return True


def extract_frames_to_tempdir(
    video_path: str,
    target_fps: float,
    max_frames: int | None = None,
    parent_dir: str | None = None,
    prefix: str = "frames_",
) -> str:
    temp_dir = tempfile.mkdtemp(prefix=prefix, dir=parent_dir)

    container = av.open(video_path)
    video_stream = container.streams.video[0]

    if video_stream.duration is None or video_stream.time_base is None:
        raise ValueError("video_stream is missing duration or time_base")

    time_base = video_stream.time_base
    duration = float(video_stream.duration * time_base)
    interval = 1.0 / target_fps

    total_frames = int(duration * target_fps)
    if max_frames is not None:
        total_frames = min(total_frames, max_frames)

    target_times = [i * interval for i in range(total_frames)]
    target_index = 0

    # Keep the first decoded frame that falls within half an interval of each
    # target timestamp, which approximates uniform sampling at target_fps.
    for frame in container.decode(video=0):
        if frame.pts is None:
            continue

        timestamp = float(frame.pts * time_base)

        if target_index < len(target_times) and abs(timestamp - target_times[target_index]) < (interval / 2):
            frame_path = pathlib.Path(temp_dir) / f"frame_{target_index:04d}.jpg"
            frame.to_image().save(frame_path)
            target_index += 1

            if max_frames is not None and target_index >= max_frames:
                break

    container.close()
    return temp_dir


def process_new_user_message(message: dict) -> list[dict]:
    if not message["files"]:
        return [{"type": "text", "text": message["text"]}]

    file_types = [get_file_type(path) for path in message["files"]]

    if len(file_types) == 1 and file_types[0] == "video":
        gr.Info(f"Video will be processed at {TARGET_FPS} FPS, max {MAX_FRAMES} frames in this Space.")

        temp_dir = extract_frames_to_tempdir(
            message["files"][0],
            target_fps=TARGET_FPS,
            max_frames=MAX_FRAMES,
            parent_dir=GRADIO_TEMP_DIR,
        )
        paths = sorted(pathlib.Path(temp_dir).glob("*.jpg"))
        return [
            {"type": "text", "text": message["text"]},
            *[{"type": "image", "image": path.as_posix()} for path in paths],
        ]

    return [
        {"type": "text", "text": message["text"]},
        *[{"type": file_type, file_type: path} for path, file_type in zip(message["files"], file_types, strict=True)],
    ]


def process_history(history: list[dict]) -> list[dict]:
    messages = []
    current_user_content: list[dict] = []
    for item in history:
        if item["role"] == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
        else:
            content = item["content"]
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            else:
                filepath = content[0]
                file_type = get_file_type(filepath)
                current_user_content.append({"type": file_type, file_type: filepath})
    return messages


@torch.inference_mode()
def generate(message: dict, history: list[dict], model_id: str, max_new_tokens: int = 512) -> Iterator[str]:
    system_prompt = "You are a helpful assistant."

    update_model(model_id)
    processor = model_cache["processor"]
    model = model_cache["model"]

    if not validate_media_constraints(message):
        yield ""
        return

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    n_tokens = inputs["input_ids"].shape[1]
    if n_tokens > MAX_INPUT_TOKENS:
        gr.Warning(
            f"Input too long. Max {MAX_INPUT_TOKENS} tokens. Got {n_tokens} tokens. "
            "This limit is set to avoid out-of-memory errors in this Space."
        )
        yield ""
        return

    # inputs = inputs.to(device=model.device, dtype=torch.bfloat16)

    # Run generation in a background thread so the streamer can yield partial
    # text as soon as new tokens are produced.
    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        disable_compile=True,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    output = ""
    for delta in streamer:
        output += delta
        yield output


examples = [
    [
        {
            "text": "What is on the flower?",
            "files": ["assets/bee.jpg"],
        }
    ],
    [
        {
            "text": "Describe this image in detail.",
            "files": ["assets/dogs.jpg"],
        }
    ],
    [
        {
            "text": "Give me a short and easy recipe for this dish",
            "files": ["assets/recipe_burger.webp"],
        }
    ],
    [
        {
            "text": "I want to go somewhere similar to the one in the photo. Give me destinations and travel tips",
            "files": ["assets/travel_tips.jpg"],
        }
    ],
    [
        {
            "text": "As an art critic AI assistant, could you describe this painting in detail and make a thorough critique?",
            "files": ["assets/art_critic.png"],
        }
    ],
    [
        {
            "text": "What is the capital of France?",
            "files": [],
        }
    ],
]

model_choices = [
    "echarlaix/SmolVLM2-500M-Video-Instruct-openvino",
    "echarlaix/SmolVLM2-500M-Video-Instruct-openvino-8bit-static",
    "echarlaix/SmolVLM2-500M-Video-Instruct-openvino-8bit-woq",
]

demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    textbox=gr.MultimodalTextbox(
        file_types=list(IMAGE_FILE_TYPES + VIDEO_FILE_TYPES),
        file_count="multiple",
        autofocus=True,
    ),
    multimodal=True,
    additional_inputs=[
        gr.Dropdown(model_choices, value=model_choices[0], label="Model ID"),
        # gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
    ],
    stop_btn=False,
    title="Fast quantized SmolVLM2 ⚡",
    description=(
        "Play with [SmolVLM2-500M-Video-Instruct](https://huggingface.co/echarlaix/SmolVLM2-500M-Video-Instruct-openvino) "
        "and its quantized variants: [SmolVLM2-500M-Video-Instruct-openvino-8bit-woq](https://huggingface.co/echarlaix/SmolVLM2-500M-Video-Instruct-openvino-8bit-woq) "
        "and [SmolVLM2-500M-Video-Instruct-openvino-8bit-static](https://huggingface.co/echarlaix/SmolVLM2-500M-Video-Instruct-openvino-8bit-static), "
        "obtained by applying Weight-Only Quantization and Static Quantization respectively, using the "
        "[Optimum Intel](https://github.com/huggingface/optimum-intel) NNCF integration. "
        "To get started, upload an image and text or try one of the examples. "
        "This demo runs on 4th Generation Intel Xeon (Sapphire Rapids) processors."
    ),
    examples=examples,
    run_examples_on_click=False,
    cache_examples=False,
    css_paths="style.css",
    delete_cache=(1800, 1800),
)

if __name__ == "__main__":
    demo.launch()
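
# ---------------------------------------------------------------------------
# Optional local smoke test: a minimal sketch, not part of the Space UI.
# It assumes the default OpenVINO model above can be downloaded locally and
# simply streams a text-only reply through generate() without launching the
# Gradio app. The helper name _smoke_test is hypothetical. Uncomment to try.
#
# def _smoke_test() -> None:
#     last = ""
#     for partial in generate(
#         {"text": "What is the capital of France?", "files": []},
#         history=[],
#         model_id=default_model_id,
#     ):
#         last = partial
#     print(last)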