|
|
""" |
|
|
SAM2 Video Segmentation Space |
|
|
Removes background from videos by tracking specified objects. |
|
|
Provides both Gradio UI and API endpoints. |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import numpy as np |
|
|
import cv2 |
|
|
import tempfile |
|
|
import os |
|
|
from pathlib import Path |
|
|
from typing import List, Tuple, Optional, Dict, Any |
|
|
from transformers import Sam2VideoModel, Sam2VideoProcessor |
|
|
from transformers.video_utils import load_video |
|
|
from PIL import Image |
|
|
import json |
|
|
|
|
|
|
|
|
MODEL_NAME = "facebook/sam2.1-hiera-tiny" |
|
|
device = None |
|
|
model = None |
|
|
processor = None |
|
|
|
|
|
|
|
|
def initialize_model(): |
|
|
"""Initialize SAM2 model and processor.""" |
|
|
global device, model, processor |
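    # Use float16 on CUDA to save memory; fall back to float32 on MPS and CPU.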
|
|
|
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
device = torch.device("cuda") |
|
|
dtype = torch.float16 |
|
|
elif torch.backends.mps.is_available(): |
|
|
device = torch.device("mps") |
|
|
dtype = torch.float32 |
|
|
else: |
|
|
device = torch.device("cpu") |
|
|
dtype = torch.float32 |
|
|
|
|
|
print(f"Loading SAM2 model on {device}...") |
|
|
|
|
|
|
|
|
model = Sam2VideoModel.from_pretrained(MODEL_NAME).to(device, dtype=dtype) |
|
|
processor = Sam2VideoProcessor.from_pretrained(MODEL_NAME) |
|
|
|
|
|
print("Model loaded successfully!") |
|
|
return device, model, processor |
|
|
|
|
|
|
|
|
def extract_frames_from_video(video_path: str, max_frames: Optional[int] = None) -> Tuple[List[Image.Image], Dict]: |
|
|
"""Extract frames from video file.""" |
|
|
video_frames, info = load_video(video_path) |
|
|
|
|
|
if max_frames and len(video_frames) > max_frames: |
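        # Subsample frames evenly across the clip so long videos stay within max_frames.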
|
|
|
|
|
indices = np.linspace(0, len(video_frames) - 1, max_frames, dtype=int) |
|
|
video_frames = [video_frames[i] for i in indices] |
|
|
|
|
|
return video_frames, info |
|
|
|
|
|
|
|
|
def create_output_video( |
|
|
video_frames: List[Image.Image], |
|
|
masks: Dict[int, torch.Tensor], |
|
|
output_path: str, |
|
|
fps: float = 30.0, |
|
|
remove_background: bool = True |
|
|
) -> str: |
|
|
""" |
|
|
Create output video with segmented objects. |
|
|
|
|
|
Args: |
|
|
video_frames: Original video frames |
|
|
masks: Dictionary mapping frame_idx to mask tensors |
|
|
output_path: Path to save output video |
|
|
fps: Frames per second |
|
|
remove_background: If True, remove background; if False, highlight objects |
|
|
""" |
|
|
if not masks: |
|
|
raise ValueError("No masks provided") |
|
|
|
|
|
|
|
|
first_frame = np.array(video_frames[0]) |
|
|
height, width = first_frame.shape[:2] |
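    # 'mp4v' (MPEG-4 Part 2) is widely available in OpenCV builds; H.264 ('avc1') may not be.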
|
|
|
|
|
|
|
|
fourcc = cv2.VideoWriter_fourcc(*'mp4v') |
|
|
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) |
|
|
|
|
|
for frame_idx, frame_pil in enumerate(video_frames): |
|
|
frame = np.array(frame_pil) |
|
|
|
|
|
if frame_idx in masks: |
|
|
            mask = masks[frame_idx].float().cpu().numpy()  # float32: cv2.resize cannot handle float16
|
|
|
|
|
|
|
|
            # Collapse any leading object/channel dimensions into a single 2D mask
            # so that every tracked object contributes to the combined foreground.
            if mask.ndim > 2:
                mask = mask.reshape(-1, *mask.shape[-2:]).max(axis=0)
|
|
|
|
|
|
|
|
if mask.shape != (height, width): |
|
|
mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST) |
|
|
|
|
|
|
|
|
mask_binary = (mask > 0.5).astype(np.uint8) |
|
|
|
|
|
            if remove_background:
                # MP4 cannot store an alpha channel, so instead of true
                # transparency the background is composited to black.
                background = np.zeros_like(frame)
                mask_3d = np.repeat(mask_binary[:, :, np.newaxis], 3, axis=2)
                frame = (frame * mask_3d + background * (1 - mask_3d)).astype(np.uint8)
|
|
else: |
|
|
|
|
|
overlay = frame.copy() |
|
|
overlay[mask_binary > 0] = [0, 255, 0] |
|
|
frame = cv2.addWeighted(frame, 0.7, overlay, 0.3, 0) |
|
|
|
|
|
|
|
|
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) |
|
|
out.write(frame_bgr) |
|
|
|
|
|
out.release() |
|
|
return output_path |
|
|
|
|
|
|
|
|
def segment_video( |
|
|
video_path: str, |
|
|
annotations: List[Dict[str, Any]], |
|
|
remove_background: bool = True, |
|
|
max_frames: Optional[int] = None |
|
|
) -> str: |
|
|
""" |
|
|
Main function to segment video based on annotations. |
|
|
|
|
|
Args: |
|
|
video_path: Path to input video |
|
|
annotations: List of annotation dictionaries with format: |
|
|
[ |
|
|
{ |
|
|
"frame_idx": 0, |
|
|
"object_id": 1, |
|
|
"points": [[x1, y1], [x2, y2], ...], |
|
|
"labels": [1, 1, ...] # 1 for foreground, 0 for background |
|
|
}, |
|
|
... |
|
|
] |
|
|
remove_background: If True, remove background; if False, highlight objects |
|
|
max_frames: Maximum number of frames to process (None = all frames) |
|
|
|
|
|
Returns: |
|
|
Path to output video file |
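
    Example (illustrative call; "input.mp4" is a placeholder path):
        annotations = [{"frame_idx": 0, "object_id": 1,
                        "points": [[320, 240]], "labels": [1]}]
        output_path = segment_video("input.mp4", annotations, max_frames=300)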
|
|
""" |
|
|
global device, model, processor |
|
|
|
|
|
if model is None: |
|
|
initialize_model() |
|
|
|
|
|
|
|
|
print("Loading video frames...") |
|
|
video_frames, video_info = extract_frames_from_video(video_path, max_frames) |
|
|
fps = video_info.get('fps', 30.0) |
|
|
|
|
|
print(f"Processing {len(video_frames)} frames at {fps} FPS") |
|
|
|
|
|
|
|
|
dtype = torch.float16 if device.type == "cuda" else torch.float32 |
|
|
inference_session = processor.init_video_session( |
|
|
video=video_frames, |
|
|
inference_device=device, |
|
|
dtype=dtype, |
|
|
) |
|
|
|
|
|
|
|
|
print("Adding annotations...") |
|
|
for ann in annotations: |
|
|
frame_idx = ann["frame_idx"] |
|
|
obj_id = ann["object_id"] |
|
|
points = ann.get("points", []) |
|
|
labels = ann.get("labels", [1] * len(points)) |
|
|
|
|
|
if points: |
|
|
|
|
|
formatted_points = [[points]] |
|
|
formatted_labels = [[labels]] |
|
|
|
|
|
processor.add_inputs_to_inference_session( |
|
|
inference_session=inference_session, |
|
|
frame_idx=frame_idx, |
|
|
obj_ids=obj_id, |
|
|
input_points=formatted_points, |
|
|
input_labels=formatted_labels, |
|
|
) |
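        # Run the model on the annotated frame; this seeds the inference session with the object's initial mask.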
|
|
|
|
|
|
|
|
outputs = model( |
|
|
inference_session=inference_session, |
|
|
frame_idx=frame_idx, |
|
|
) |
|
|
|
|
|
|
|
|
print("Propagating masks through video...") |
|
|
video_segments = {} |
|
|
|
|
|
for sam2_output in model.propagate_in_video_iterator(inference_session): |
|
|
video_res_masks = processor.post_process_masks( |
|
|
[sam2_output.pred_masks], |
|
|
original_sizes=[[inference_session.video_height, inference_session.video_width]], |
|
|
binarize=False |
|
|
)[0] |
|
|
video_segments[sam2_output.frame_idx] = video_res_masks |
|
|
|
|
|
print(f"Generated masks for {len(video_segments)} frames") |
|
|
|
|
|
|
|
|
    output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
|
|
print("Creating output video...") |
|
|
create_output_video(video_frames, video_segments, output_path, fps, remove_background) |
|
|
|
|
|
print(f"Output video saved to: {output_path}") |
|
|
return output_path |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_segment_video( |
|
|
video_file, |
|
|
annotation_json: str, |
|
|
remove_bg: bool = True, |
|
|
max_frames: Optional[int] = None |
|
|
): |
|
|
""" |
|
|
Gradio wrapper for video segmentation. |
|
|
|
|
|
Args: |
|
|
video_file: Uploaded video file |
|
|
annotation_json: JSON string with annotations |
|
|
remove_bg: Whether to remove background |
|
|
max_frames: Maximum frames to process |
|
|
""" |
|
|
try: |
|
|
|
|
|
annotations = json.loads(annotation_json) |
|
|
|
|
|
if not isinstance(annotations, list): |
|
|
return None, "Error: Annotations must be a list of objects" |
|
|
|
|
|
|
|
|
output_path = segment_video( |
|
|
video_path=video_file, |
|
|
annotations=annotations, |
|
|
remove_background=remove_bg, |
|
|
max_frames=max_frames |
|
|
) |
|
|
|
|
|
return output_path, "✅ Video processed successfully!" |
|
|
|
|
|
except json.JSONDecodeError as e: |
|
|
return None, f"❌ JSON parsing error: {str(e)}" |
|
|
except Exception as e: |
|
|
return None, f"❌ Error: {str(e)}" |
|
|
|
|
|
|
|
|
def gradio_simple_segment( |
|
|
video_file, |
|
|
point_x: int, |
|
|
point_y: int, |
|
|
frame_idx: int = 0, |
|
|
remove_bg: bool = True, |
|
|
max_frames: Optional[int] = 300 |
|
|
): |
|
|
""" |
|
|
Simple Gradio interface with single point annotation. |
|
|
""" |
|
|
try: |
|
|
|
|
|
annotations = [{ |
|
|
"frame_idx": frame_idx, |
|
|
"object_id": 1, |
|
|
"points": [[point_x, point_y]], |
|
|
"labels": [1] |
|
|
}] |
|
|
|
|
|
|
|
|
output_path = segment_video( |
|
|
video_path=video_file, |
|
|
annotations=annotations, |
|
|
remove_background=remove_bg, |
|
|
max_frames=max_frames |
|
|
) |
|
|
|
|
|
return output_path, f"✅ Video processed! Tracked from point ({point_x}, {point_y}) on frame {frame_idx}" |
|
|
|
|
|
except Exception as e: |
|
|
return None, f"❌ Error: {str(e)}" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def api_segment_video(video_file, annotations_json: str, remove_background: bool = True, max_frames: Optional[int] = None):
|
|
""" |
|
|
API endpoint for video segmentation. |
|
|
Can be called via gradio_client or direct HTTP requests. |
|
|
""" |
|
|
annotations = json.loads(annotations_json) |
|
|
output_path = segment_video(video_file, annotations, remove_background, max_frames) |
|
|
return output_path |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_interface(): |
|
|
"""Create the Gradio interface.""" |
|
|
|
|
|
|
|
|
initialize_model() |
|
|
|
|
|
|
|
|
with gr.Blocks(title="SAM2 Video Segmentation - Remove Background") as app: |
|
|
gr.Markdown(""" |
|
|
# 🎥 SAM2 Video Background Remover |
|
|
|
|
|
Remove backgrounds from videos by tracking objects. Uses Meta's Segment Anything Model 2 (SAM2). |
|
|
|
|
|
    **Three ways to use this:**
|
|
1. **Simple Mode**: Click on an object in the first frame |
|
|
2. **Advanced Mode**: Provide detailed JSON annotations |
|
|
3. **API Mode**: Use the API endpoint programmatically |
|
|
""") |
|
|
|
|
|
with gr.Tabs(): |
|
|
|
|
|
with gr.Tab("Simple Mode"): |
|
|
gr.Markdown(""" |
|
|
### Quick Start |
|
|
1. Upload a video |
|
|
2. Specify the coordinates of the object you want to track |
|
|
3. Click "Process Video" |
|
|
|
|
|
    **Tip**: Open the first frame of your video in an image viewer to find the x,y coordinates of your target object.
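
    If it helps, a small OpenCV snippet (illustrative; `input.mp4` is a placeholder) can grab that first frame so you can inspect it:

    ```python
    import cv2

    ok, frame = cv2.VideoCapture("input.mp4").read()  # read the first frame
    if ok:
        cv2.imwrite("first_frame.png", frame)
        print(frame.shape)  # (height, width, channels); point X is the column, Y is the row
    ```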
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
simple_video_input = gr.Video(label="Upload Video") |
|
|
|
|
|
with gr.Row(): |
|
|
point_x_input = gr.Number(label="Point X", value=320, precision=0) |
|
|
point_y_input = gr.Number(label="Point Y", value=240, precision=0) |
|
|
|
|
|
frame_idx_input = gr.Number(label="Frame Index", value=0, precision=0, |
|
|
info="Which frame to annotate (usually 0 for first frame)") |
|
|
|
|
|
remove_bg_simple = gr.Checkbox(label="Remove Background", value=True, |
|
|
info="If checked, removes background. If unchecked, highlights object.") |
|
|
|
|
|
max_frames_simple = gr.Number(label="Max Frames (optional)", value=300, precision=0, |
|
|
info="Limit frames for faster processing. Leave at 0 for all frames.") |
|
|
|
|
|
simple_btn = gr.Button("🎬 Process Video", variant="primary") |
|
|
|
|
|
with gr.Column(): |
|
|
simple_output_video = gr.Video(label="Output Video") |
|
|
simple_status = gr.Textbox(label="Status", lines=3) |
|
|
|
|
|
simple_btn.click( |
|
|
fn=gradio_simple_segment, |
|
|
inputs=[simple_video_input, point_x_input, point_y_input, frame_idx_input, |
|
|
remove_bg_simple, max_frames_simple], |
|
|
outputs=[simple_output_video, simple_status] |
|
|
) |
|
|
|
|
|
gr.Markdown(""" |
|
|
### Example: |
|
|
For a 640x480 video with a person in the center, try: X=320, Y=240, Frame=0 |
|
|
""") |
|
|
|
|
|
|
|
|
with gr.Tab("Advanced Mode (JSON)"): |
|
|
gr.Markdown(""" |
|
|
### Advanced Annotations |
|
|
Provide detailed JSON annotations for multiple objects and frames. |
|
|
|
|
|
**JSON Format:** |
|
|
```json |
|
|
[ |
|
|
{ |
|
|
"frame_idx": 0, |
|
|
"object_id": 1, |
|
|
"points": [[x1, y1], [x2, y2]], |
|
|
"labels": [1, 1] |
|
|
} |
|
|
] |
|
|
``` |
|
|
|
|
|
- `frame_idx`: Frame number to annotate |
|
|
- `object_id`: Unique ID for each object (1, 2, 3, ...) |
|
|
- `points`: List of [x, y] coordinates |
|
|
- `labels`: 1 for foreground point, 0 for background point |
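
            **Example** (two objects, illustrative coordinates):

            ```json
            [
              {"frame_idx": 0, "object_id": 1, "points": [[200, 180]], "labels": [1]},
              {"frame_idx": 0, "object_id": 2, "points": [[450, 300], [470, 320]], "labels": [1, 1]}
            ]
            ```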
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
adv_video_input = gr.Video(label="Upload Video") |
|
|
|
|
|
adv_annotation_input = gr.Textbox( |
|
|
label="Annotations (JSON)", |
|
|
lines=10, |
|
|
value='''[ |
|
|
{ |
|
|
"frame_idx": 0, |
|
|
"object_id": 1, |
|
|
"points": [[320, 240]], |
|
|
"labels": [1] |
|
|
} |
|
|
]''', |
|
|
placeholder="Enter JSON annotations here..." |
|
|
) |
|
|
|
|
|
remove_bg_adv = gr.Checkbox(label="Remove Background", value=True) |
|
|
max_frames_adv = gr.Number(label="Max Frames (0 = all)", value=0, precision=0) |
|
|
|
|
|
adv_btn = gr.Button("🎬 Process Video", variant="primary") |
|
|
|
|
|
with gr.Column(): |
|
|
adv_output_video = gr.Video(label="Output Video") |
|
|
adv_status = gr.Textbox(label="Status", lines=3) |
|
|
|
|
|
adv_btn.click( |
|
|
fn=gradio_segment_video, |
|
|
inputs=[adv_video_input, adv_annotation_input, remove_bg_adv, max_frames_adv], |
|
|
outputs=[adv_output_video, adv_status] |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Tab("API Documentation"): |
|
|
gr.Markdown(""" |
|
|
## 📡 API Usage |
|
|
|
|
|
This Space exposes an API that you can call programmatically. |
|
|
|
|
|
### Using Python with `gradio_client` |
|
|
|
|
|
```python |
|
|
from gradio_client import Client |
|
|
import json |
|
|
|
|
|
# Connect to the Space |
|
|
client = Client("YOUR_USERNAME/YOUR_SPACE_NAME") |
|
|
|
|
|
# Define annotations |
|
|
annotations = [ |
|
|
{ |
|
|
"frame_idx": 0, |
|
|
"object_id": 1, |
|
|
"points": [[320, 240]], |
|
|
"labels": [1] |
|
|
} |
|
|
] |
|
|
|
|
|
# Call the API |
|
|
result = client.predict( |
|
|
video_file="path/to/video.mp4", |
|
|
annotations_json=json.dumps(annotations), |
|
|
remove_background=True, |
|
|
max_frames=300, |
|
|
api_name="/segment_video_api" |
|
|
) |
|
|
|
|
|
print(f"Output video: {result}") |
|
|
``` |
|
|
|
|
|
            ### Using cURL

            The exact raw-HTTP route depends on the Gradio version running in the Space, so the `gradio_client` example above is the recommended client. A multipart request looks roughly like this:
|
|
|
|
|
```bash |
|
|
curl -X POST https://YOUR_USERNAME-YOUR_SPACE_NAME.hf.space/api/predict \\ |
|
|
-H "Content-Type: application/json" \\ |
|
|
-F "data=@video.mp4" \\ |
|
|
-F 'annotations=[{"frame_idx":0,"object_id":1,"points":[[320,240]],"labels":[1]}]' |
|
|
``` |
|
|
|
|
|
### Parameters |
|
|
|
|
|
- **video_file**: Video file (required) |
|
|
- **annotations_json**: JSON string with annotations (required) |
|
|
- **remove_background**: Boolean (default: true) |
|
|
- **max_frames**: Integer (default: null, processes all frames) |
|
|
|
|
|
### Response |
|
|
|
|
|
Returns the path to the processed video file. |
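
            With `gradio_client`, the returned value is typically a local path to the downloaded video, which you can copy wherever you need it (illustrative snippet, reusing `result` from the Python example above):

            ```python
            import shutil

            shutil.copy(result, "segmented_video.mp4")  # persist the downloaded output
            ```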
|
|
""") |
|
|
|
|
|
|
|
|
api_interface = gr.Interface( |
|
|
fn=api_segment_video, |
|
|
inputs=[ |
|
|
gr.Video(label="Video File"), |
|
|
gr.Textbox(label="Annotations JSON"), |
|
|
gr.Checkbox(label="Remove Background", value=True), |
|
|
gr.Number(label="Max Frames", value=None, precision=0) |
|
|
], |
|
|
outputs=gr.Video(label="Output Video"), |
|
|
api_name="segment_video_api", |
|
|
visible=False |
|
|
) |
|
|
|
|
|
return app |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
app = create_interface() |
|
|
app.launch( |
|
|
server_name="0.0.0.0", |
|
|
server_port=7860, |
|
|
share=False |
|
|
) |
|
|
|
|
|
|