Create app.py
app.py
ADDED
@@ -0,0 +1,170 @@
import os
import tempfile
import numpy as np
import gradio as gr
from moviepy import VideoFileClip
import torch
import clip
import cv2
from PIL import Image
from scenedetect import VideoManager, SceneManager
from scenedetect.detectors import ContentDetector, AdaptiveDetector, ThresholdDetector, HistogramDetector, HashDetector
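# NOTE: this file assumes OpenAI's CLIP package for `import clip`
# (pip install git+https://github.com/openai/CLIP.git), the MoviePy 2.x API
# (`from moviepy import ...`, `subclipped()` below), and PySceneDetect for the
# VideoManager/SceneManager/detector imports above.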

# Device options
DEVICE_OPTIONS = {
    "cpu": "cpu",
    "cuda": "cuda" if torch.cuda.is_available() else "cpu",
    "mps": "mps" if torch.backends.mps.is_available() else "cpu"
}

def load_clip_model(device):
    return clip.load("ViT-B/32", device=device)

# --- Video Processing Functions ---
def extract_frames(video_path, fps=2):
    cap = cv2.VideoCapture(video_path)
    frames = []
    # max(1, ...) guards against a zero step (and a modulo-by-zero below)
    # when the source frame rate is at or below the target sampling fps
    frame_rate = max(1, int(cap.get(cv2.CAP_PROP_FPS) / fps))
    count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if count % frame_rate == 0:
            frames.append(frame)
        count += 1
    cap.release()
    return frames
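# With the fps=2 default above, a 30 fps source keeps every 15th frame, so
# sampled frame i lands near i / fps = i * 0.5 seconds into the video; the
# match timestamp computed in process_videos below relies on this spacing.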

def get_clip_features(frames, model, preprocess, device):
    features = []
    for frame in frames:
        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        img_input = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            feature = model.encode_image(img_input)
        features.append(feature.cpu().numpy()[0])
    return features

def compute_distance(a, b, method):
    if method == "cosine":
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    elif method == "l2":
        return np.linalg.norm(a - b)
    elif method == "l1":
        return np.sum(np.abs(a - b))
    else:
        return np.linalg.norm(a - b)
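# NOTE: for "cosine" this returns a similarity (higher = more alike), while
# "l1"/"l2" return distances (lower = more alike); find_match below flips its
# comparisons accordingly.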

def find_match(clip_feats, ref_feats, threshold=0.3, similarity="l2"):
    len_clip = len(clip_feats)
    best_match = -1
    best_score = float('inf') if similarity != "cosine" else -float('inf')
    for i in range(len(ref_feats) - len_clip + 1):
        window = ref_feats[i:i + len_clip]
        dists = [compute_distance(a, b, similarity) for a, b in zip(clip_feats, window)]
        dist = np.mean(dists)
        if (similarity != "cosine" and dist < best_score) or (similarity == "cosine" and dist > best_score):
            best_score = dist
            best_match = i
    if (similarity != "cosine" and best_score < threshold) or (similarity == "cosine" and best_score > threshold):
        return best_match, best_score
    return -1, best_score
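# Rough example of the sliding window: at 2 fps a 3-second query clip yields
# about 6 feature vectors; find_match compares that 6-vector window against
# every 6-vector window of the reference features, keeps the best mean score,
# and accepts it only if it beats the threshold.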

# Scene Detection
def get_detector(detector_name, threshold):
    if detector_name == "ContentDetector":
        return ContentDetector(threshold=threshold)
    elif detector_name == "AdaptiveDetector":
        return AdaptiveDetector()
    elif detector_name == "ThresholdDetector":
        return ThresholdDetector(threshold=threshold)
    elif detector_name == "HashDetector":
        return HashDetector(threshold=threshold)
    elif detector_name == "HistogramDetector":
        return HistogramDetector(threshold=threshold)
    else:
        return ContentDetector(threshold=threshold)

def detect_scenes(video_path, detector_name, threshold):
    video_manager = VideoManager([video_path])
    scene_manager = SceneManager()
    detector = get_detector(detector_name, threshold)
    scene_manager.add_detector(detector)
    video_manager.set_downscale_factor()
    video_manager.start()
    scene_manager.detect_scenes(frame_source=video_manager)
    scene_list = scene_manager.get_scene_list()
    return [(scene[0].get_seconds(), scene[1].get_seconds()) for scene in scene_list]
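# NOTE: VideoManager is deprecated in recent PySceneDetect releases (0.6+); it
# still works, but the newer open_video() / detect() API is the suggested
# replacement.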

def find_scene_for_timestamp(scenes, match_time):
    for start, end in scenes:
        if start <= match_time <= end:
            return (start, end)
    return None

def extract_scene(video_path, scene_range, output_path):
    start_time, end_time = scene_range
    # named scene_clip so it does not shadow the imported `clip` module
    scene_clip = VideoFileClip(video_path).subclipped(start_time, end_time)
    scene_clip.write_videofile(output_path, codec="libx264", audio_codec="aac")
    return output_path
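# NOTE: subclipped() is the MoviePy 2.x name for this operation; on the older
# 1.x API it was called subclip().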

# Main logic

def process_videos(clip_path, ref_path, match_threshold, scene_threshold, detector_type, similarity_type, device_type, output_path):
    device = DEVICE_OPTIONS.get(device_type, "cpu")
    model, preprocess = load_clip_model(device)

    clip_frames = extract_frames(clip_path)
    ref_frames = extract_frames(ref_path)

    clip_feats = get_clip_features(clip_frames, model, preprocess, device)
    ref_feats = get_clip_features(ref_frames, model, preprocess, device)

    match_index, score = find_match(clip_feats, ref_feats, match_threshold, similarity_type)

    if match_index == -1:
        return f"No match found (best score = {score:.4f})", None

    # 0.5 s per sampled frame, matching the fps=2 default in extract_frames
    match_time = match_index * 0.5
    scenes = detect_scenes(ref_path, detector_type, scene_threshold)
    matched_scene = find_scene_for_timestamp(scenes, match_time)

    if not matched_scene:
        return "Match found, but no scene boundaries detected.", None

    output_path = os.path.join(output_path, "matched_scene.mp4")
    result_path = extract_scene(ref_path, matched_scene, output_path)

    return f"Match found at ~{match_time:.2f}s (score = {score:.4f})\nScene from {matched_scene[0]:.2f}s to {matched_scene[1]:.2f}s", result_path
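# NOTE: process_videos reloads the CLIP model on every request; caching the
# loaded (model, preprocess) pair per device would avoid that repeated cost.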

# Gradio Interface
with tempfile.TemporaryDirectory() as tmpdir:
    iface = gr.Interface(
        fn=process_videos,
        inputs=[
            gr.Video(label="Clip Video"),
            gr.Video(label="Reference Video"),
            gr.Slider(0.1, 100.0, value=0.3, label="Matching Threshold (L1/L2: lower = stricter; cosine: higher = stricter)"),
            gr.Slider(0.01, 100, value=30, step=1, label="Scene Detection Threshold"),
            gr.Dropdown([
                "ContentDetector", "AdaptiveDetector", "ThresholdDetector", "HistogramDetector", "HashDetector"
            ], value="ContentDetector", label="Scene Detector Type"),
            gr.Dropdown(["l2", "l1", "cosine"], value="l2", label="Similarity Metric"),
            gr.Dropdown(["cpu", "cuda", "mps"], value="cpu", label="Processing Device"),
            gr.Text(value=tmpdir, visible=False)
        ],
        outputs=[
            gr.Text(label="Match Info"),
            gr.Video(label="Matched Scene")
        ],
        title="AI Video Clip Matcher",
        description="Upload a short video clip and a reference video. The system will try to find where the clip appears in the reference video and extract the full scene around it."
    )

    # --- Launch the App ---
    # Launch inside the TemporaryDirectory context so tmpdir (where the
    # matched scene is written) stays alive while the server is running.
    if __name__ == "__main__":
        print("Launching Gradio interface...")

        # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
        # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
        iface.launch()