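"""AI Video Clip Matcher (Gradio app).

Given a short clip and a longer reference video, find where the clip appears
in the reference using CLIP image embeddings, then extract the full scene
around the match with PySceneDetect.
"""
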
import os
import tempfile

import numpy as np
import gradio as gr
from moviepy import VideoFileClip
import torch
import clip
import cv2
from PIL import Image
from scenedetect import VideoManager, SceneManager
from scenedetect.detectors import (
    ContentDetector, AdaptiveDetector, ThresholdDetector,
    HistogramDetector, HashDetector,
)
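
# Assumed environment (not pinned in this file): OpenAI's CLIP package
# (github.com/openai/CLIP), moviepy >= 2.0 (for the `subclipped` API used
# below), PySceneDetect, OpenCV, PyTorch, Gradio, Pillow, NumPy.
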
# Device options: fall back to CPU when the requested accelerator is unavailable
DEVICE_OPTIONS = {
    "cpu": "cpu",
    "cuda": "cuda" if torch.cuda.is_available() else "cpu",
    "mps": "mps" if torch.backends.mps.is_available() else "cpu"
}

def load_clip_model(device):
    return clip.load("ViT-B/32", device=device)

# --- Video Processing Functions ---
def extract_frames(video_path, fps=2):
    """Sample frames from a video at roughly `fps` frames per second."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    # Guard against videos that report a frame rate of 0 (or below the
    # sampling rate), which would otherwise make the modulo step zero.
    step = max(1, int(cap.get(cv2.CAP_PROP_FPS) / fps))
    count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if count % step == 0:
            frames.append(frame)
        count += 1
    cap.release()
    return frames
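
# Sampling at a low rate keeps CLIP encoding cheap: with the default fps=2,
# each kept frame represents 0.5 s of video. Raising fps gives finer match
# timestamps at proportional compute cost.
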
def get_clip_features(frames, model, preprocess, device):
    features = []
    for frame in frames:
        # OpenCV decodes to BGR; CLIP's preprocess expects an RGB PIL image
        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        img_input = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            feature = model.encode_image(img_input)
        features.append(feature.cpu().numpy()[0])
    return features
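
# Each feature is a 512-dimensional ViT-B/32 image embedding, handled from
# here on as a plain NumPy vector.
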
def compute_distance(a, b, method):
    if method == "cosine":
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    elif method == "l2":
        return np.linalg.norm(a - b)
    elif method == "l1":
        return np.sum(np.abs(a - b))
    else:
        return np.linalg.norm(a - b)
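
# Note: "cosine" is a *similarity* (1.0 for identical unit vectors, higher is
# better), while "l1"/"l2" are *distances* (0.0 for identical vectors, lower
# is better) -- hence the flipped comparisons in find_match below.
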
def find_match(clip_feats, ref_feats, threshold=0.3, similarity="l2"):
    """Slide the clip's feature sequence over the reference features and
    return (best start index, best score), or (-1, score) if no window
    passes the threshold."""
    len_clip = len(clip_feats)
    best_match = -1
    best_score = float('inf') if similarity != "cosine" else -float('inf')
    for i in range(len(ref_feats) - len_clip + 1):
        window = ref_feats[i:i + len_clip]
        dists = [compute_distance(a, b, similarity) for a, b in zip(clip_feats, window)]
        dist = np.mean(dists)
        if (similarity != "cosine" and dist < best_score) or (similarity == "cosine" and dist > best_score):
            best_score = dist
            best_match = i
    if (similarity != "cosine" and best_score < threshold) or (similarity == "cosine" and best_score > threshold):
        return best_match, best_score
    return -1, best_score
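
# Quick sanity check (hypothetical values): for identical feature lists,
# find_match(feats, feats, threshold=0.3, similarity="l2") returns (0, 0.0),
# since the mean L2 distance of a sequence against itself is zero.
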
# Scene Detection
def get_detector(detector_name, threshold):
    if detector_name == "ContentDetector":
        return ContentDetector(threshold=threshold)
    elif detector_name == "AdaptiveDetector":
        return AdaptiveDetector()
    elif detector_name == "ThresholdDetector":
        return ThresholdDetector(threshold=threshold)
    elif detector_name == "HashDetector":
        return HashDetector(threshold=threshold)
    elif detector_name == "HistogramDetector":
        return HistogramDetector(threshold=threshold)
    else:
        return ContentDetector(threshold=threshold)
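
# Note: AdaptiveDetector is constructed without a threshold, so the
# "Scene Detection Threshold" slider has no effect when it is selected.
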
def detect_scenes(video_path, detector_name, threshold):
    # VideoManager is deprecated in PySceneDetect 0.6+ in favor of
    # open_video(), but it still works and is kept here for compatibility.
    video_manager = VideoManager([video_path])
    scene_manager = SceneManager()
    detector = get_detector(detector_name, threshold)
    scene_manager.add_detector(detector)
    video_manager.set_downscale_factor()
    video_manager.start()
    scene_manager.detect_scenes(frame_source=video_manager)
    scene_list = scene_manager.get_scene_list()
    return [(scene[0].get_seconds(), scene[1].get_seconds()) for scene in scene_list]
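
# Scenes come back as contiguous (start, end) pairs converted to seconds, so
# find_scene_for_timestamp below can locate the enclosing scene with a
# simple linear scan.
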
def find_scene_for_timestamp(scenes, match_time):
    for start, end in scenes:
        if start <= match_time <= end:
            return (start, end)
    return None

def extract_scene(video_path, scene_range, output_path):
    start_time, end_time = scene_range
    # Named scene_clip to avoid shadowing the imported `clip` module
    scene_clip = VideoFileClip(video_path).subclipped(start_time, end_time)
    scene_clip.write_videofile(output_path, codec="libx264", audio_codec="aac")
    return output_path
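
# libx264 + AAC keeps the extracted MP4 playable in the browser video
# element that Gradio's Video output renders.
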
# Main logic
def process_videos(clip_path, ref_path, match_threshold, scene_threshold, detector_type, similarity_type, device_type, output_dir):
    device = DEVICE_OPTIONS.get(device_type, "cpu")
    model, preprocess = load_clip_model(device)
    clip_frames = extract_frames(clip_path)
    ref_frames = extract_frames(ref_path)
    clip_feats = get_clip_features(clip_frames, model, preprocess, device)
    ref_feats = get_clip_features(ref_frames, model, preprocess, device)
    match_index, score = find_match(clip_feats, ref_feats, match_threshold, similarity_type)
    if match_index == -1:
        return f"No match found (best score = {score:.4f})", None
    # Frames were sampled at 2 fps (extract_frames default), so each match
    # index corresponds to 0.5 s of video.
    match_time = match_index * 0.5
    scenes = detect_scenes(ref_path, detector_type, scene_threshold)
    matched_scene = find_scene_for_timestamp(scenes, match_time)
    if not matched_scene:
        return "Match found, but no scene boundaries detected.", None
    output_file = os.path.join(output_dir, "matched_scene.mp4")
    result_path = extract_scene(ref_path, matched_scene, output_file)
    return f"Match found at ~{match_time:.2f}s (score = {score:.4f})\nScene from {matched_scene[0]:.2f}s to {matched_scene[1]:.2f}s", result_path
# Gradio Interface
# Use mkdtemp() rather than a TemporaryDirectory context manager so the
# output directory is not deleted before the app serves requests.
tmpdir = tempfile.mkdtemp()
iface = gr.Interface(
    fn=process_videos,
    inputs=[
        gr.Video(label="Clip Video"),
        gr.Video(label="Reference Video"),
        gr.Slider(0.1, 100.0, value=0.3, label="Matching Threshold (L1/L2: match if distance is below; cosine: match if similarity is above)"),
        gr.Slider(0.01, 100, value=30, step=1, label="Scene Detection Threshold"),
        gr.Dropdown([
            "ContentDetector", "AdaptiveDetector", "ThresholdDetector", "HistogramDetector", "HashDetector"
        ], value="ContentDetector", label="Scene Detector Type"),
        gr.Dropdown(["l2", "l1", "cosine"], value="l2", label="Similarity Metric"),
        gr.Dropdown(["cpu", "cuda", "mps"], value="cpu", label="Processing Device"),
        gr.Text(value=tmpdir, visible=False)
    ],
    outputs=[
        gr.Text(label="Match Info"),
        gr.Video(label="Matched Scene")
    ],
    title="AI Video Clip Matcher",
    description="Upload a short video clip and a reference video. The system will try to find where the clip appears in the reference video and extract the full scene around it."
)

# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio interface...")
    # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
    # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
    iface.launch()
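
# Example invocation (port value is illustrative):
#   GRADIO_SERVER_NAME=0.0.0.0 GRADIO_SERVER_PORT=7860 python app.py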