# SAM3-Video / app.py
# Hugging Face Space by linoyts (HF Staff) — "Update app.py", commit c7dac2c (verified)
import os
import cv2
import tempfile
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import Sam3VideoModel, Sam3VideoProcessor
# Pick GPU when available; the model is loaded once at import time so every
# request served by this process reuses the same weights.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print("Loading SAM3 Video Model...")
# bfloat16 halves memory vs fp32. NOTE(review): assumes the chosen device
# supports bf16 inference (true for recent GPUs; CPU path should be confirmed).
VID_MODEL = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("facebook/sam3")
print("Model loaded!")
# Frame rate of the rendered output video (independent of the input's fps).
OUTPUT_FPS = 24
def apply_green_mask(base_image, mask_data, opacity=0.5):
    """Overlay a semi-transparent green mask on a single frame.

    Args:
        base_image: RGB frame as a PIL.Image or an HxWx3 numpy array.
        mask_data: Mask as a numpy array or torch tensor; 2D, 3D (multiple
            objects, merged into a union), or 4D (leading batch dim dropped).
            Any nonzero pixel is treated as masked. None/empty means no mask.
        opacity: Overlay opacity in [0, 1] applied to masked pixels.

    Returns:
        PIL.Image in RGB mode with the green overlay composited on top.
    """
    if isinstance(base_image, np.ndarray):
        base_image = Image.fromarray(base_image)
    base_image = base_image.convert("RGBA")
    # No mask at all: hand back the frame unchanged (as RGB).
    if mask_data is None or len(mask_data) == 0:
        return base_image.convert("RGB")
    if isinstance(mask_data, torch.Tensor):
        mask_data = mask_data.cpu().numpy()
    mask_data = mask_data.astype(np.uint8)
    # Collapse batch / singleton leading dims down to at most (objects, H, W).
    if mask_data.ndim == 4:
        mask_data = mask_data[0]
    if mask_data.ndim == 3 and mask_data.shape[0] == 1:
        mask_data = mask_data[0]
    if mask_data.ndim == 3:
        # Multiple object masks — merge into a single union mask.
        mask_data = np.any(mask_data > 0, axis=0).astype(np.uint8)
    # Binarize to {0, 1} before scaling: masks already encoded as 0/255 (or
    # any value > 1) would otherwise overflow uint8 in the * 255 below
    # (255 * 255 % 256 == 1), making the overlay silently invisible.
    mask_data = (mask_data > 0).astype(np.uint8)
    green = (0, 255, 0)
    mask_bitmap = Image.fromarray((mask_data * 255).astype(np.uint8))
    if mask_bitmap.size != base_image.size:
        # NEAREST keeps the mask hard-edged (no interpolated gray alpha).
        mask_bitmap = mask_bitmap.resize(base_image.size, resample=Image.NEAREST)
    color_fill = Image.new("RGBA", base_image.size, green + (0,))
    # Scale the 0/255 bitmap by the requested opacity to build the alpha plane.
    mask_alpha = mask_bitmap.point(lambda v: int(v * opacity) if v > 0 else 0)
    color_fill.putalpha(mask_alpha)
    return Image.alpha_composite(base_image, color_fill).convert("RGB")
def get_video_info(video_path):
    """Probe a video file.

    Returns:
        Tuple of (total_frames, fps, duration_in_seconds). Falls back to
        24 fps when OpenCV reports 0 (e.g. unreadable metadata).
    """
    capture = cv2.VideoCapture(video_path)
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_rate = capture.get(cv2.CAP_PROP_FPS) or 24
    capture.release()
    return frame_count, frame_rate, frame_count / frame_rate
def calc_timeout(source_vid, text_query):
    """Estimate the GPU-allocation timeout (seconds) for @spaces.GPU.

    Budgets roughly 3 seconds of processing per second of input video plus
    30s of overhead, clamped to the [60, 300] range. (The previous comment
    claimed ~2s/s, contradicting the formula; the code is authoritative.)

    Args:
        source_vid: Path to the uploaded video, or falsy when none is set.
        text_query: Unused here; present because Gradio passes all handler
            inputs to the duration callback.

    Returns:
        Timeout in whole seconds (60 when no video is provided).
    """
    if not source_vid:
        return 60
    _, _, duration = get_video_info(source_vid)
    return max(60, min(int(duration * 3) + 30, 300))
@spaces.GPU(duration=calc_timeout)
def run_video_segmentation(source_vid, text_query):
    """Segment objects matching `text_query` across a whole video.

    Decodes every frame, runs SAM3 video propagation with the text prompt,
    overlays a green mask per frame, and re-encodes at OUTPUT_FPS.

    Args:
        source_vid: Path to the input video file.
        text_query: Text description of what to segment (e.g. "person").

    Returns:
        Tuple of (output_video_path, status_message). On processing failure
        returns (None, "Error: ...") so the UI shows the message instead of
        crashing.

    Raises:
        gr.Error: When the model is unavailable or inputs are missing
            (raised before the try block, so never swallowed below).
    """
    if VID_MODEL is None or VID_PROCESSOR is None:
        raise gr.Error("Video model failed to load.")
    if not source_vid or not text_query:
        raise gr.Error("Please provide both a video and a text prompt.")
    try:
        # Decode all frames up front (BGR -> RGB) so the session can index them.
        cap = cv2.VideoCapture(source_vid)
        src_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        src_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        src_fps = cap.get(cv2.CAP_PROP_FPS) or 24
        video_frames = []
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Release the capture even if a read/convert blows up mid-loop.
            cap.release()
        total_frames = len(video_frames)
        duration = total_frames / src_fps
        status = f"Loaded {total_frames} frames ({duration:.1f}s @ {src_fps:.0f}fps). Processing..."
        print(status)
        session = VID_PROCESSOR.init_video_session(
            video=video_frames, inference_device=device, dtype=torch.bfloat16
        )
        session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=text_query)
        # mkstemp instead of the deprecated, race-prone tempfile.mktemp.
        fd, temp_out = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)  # cv2.VideoWriter opens the path itself
        writer = cv2.VideoWriter(temp_out, cv2.VideoWriter_fourcc(*"mp4v"), OUTPUT_FPS, (src_w, src_h))
        try:
            for model_out in VID_MODEL.propagate_in_video_iterator(
                inference_session=session, max_frame_num_to_track=total_frames
            ):
                post = VID_PROCESSOR.postprocess_outputs(session, model_out)
                f_idx = model_out.frame_idx
                original = Image.fromarray(video_frames[f_idx])
                if "masks" in post:
                    masks = post["masks"]
                    if masks.ndim == 4:
                        # Drop the per-object singleton channel dimension.
                        masks = masks.squeeze(1)
                    frame_out = apply_green_mask(original, masks)
                else:
                    frame_out = original
                writer.write(cv2.cvtColor(np.array(frame_out), cv2.COLOR_RGB2BGR))
        finally:
            # Always finalize the container so a partial file isn't left locked.
            writer.release()
        out_info = f"Done — {total_frames} frames, {duration:.1f}s input → output at {OUTPUT_FPS}fps"
        return temp_out, out_info
    except Exception as e:
        # Surface the failure in the status box rather than a hard crash.
        return None, f"Error: {str(e)}"
# Center the app column and cap its width.
css = """
#col-container { margin: 0 auto; max-width: 1000px; }
"""
# Build the Gradio UI: video + prompt on the left, result + status on the right.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# SAM3 Video Segmentation — Green Mask")
        gr.Markdown(
            "Upload a video and describe what to segment. "
            "Output is rendered at **24fps** with a **green mask** overlay."
        )
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Input Video", format="mp4")
                text_prompt = gr.Textbox(
                    label="Text Prompt",
                    placeholder="e.g., person, red car, dog",
                )
                run_btn = gr.Button("Segment Video", variant="primary", size="lg")
            with gr.Column():
                video_output = gr.Video(label="Segmented Video", autoplay=True)
                status_box = gr.Textbox(label="Status", interactive=False)
        # Wire the button to the GPU-decorated handler; outputs map to
        # (video path, status message).
        run_btn.click(
            fn=run_video_segmentation,
            inputs=[video_input, text_prompt],
            outputs=[video_output, status_box],
        )
if __name__ == "__main__":
    # show_error surfaces handler exceptions in the browser UI.
    demo.launch(show_error=True)