omarabb315 committed
Commit 5a0b543 · verified · 1 Parent(s): bce1b3b

Create app.py

Files changed (1)
app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
import os
import tempfile
import numpy as np
import gradio as gr
from moviepy import VideoFileClip
import torch
import clip
import cv2
from PIL import Image
from scenedetect import VideoManager, SceneManager
from scenedetect.detectors import ContentDetector, AdaptiveDetector, ThresholdDetector, HistogramDetector, HashDetector
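
# NOTE (assumption, not part of the original commit): this app expects
# moviepy >= 2.0 (which provides `VideoFileClip.subclipped`), OpenAI's CLIP
# package, opencv-python, scenedetect, torch, and gradio. A hypothetical
# install command:
#   pip install gradio "moviepy>=2.0" opencv-python scenedetect torch git+https://github.com/openai/CLIP.git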

# Device options: fall back to CPU when the requested accelerator is unavailable.
DEVICE_OPTIONS = {
    "cpu": "cpu",
    "cuda": "cuda" if torch.cuda.is_available() else "cpu",
    "mps": "mps" if torch.backends.mps.is_available() else "cpu"
}

def load_clip_model(device):
    # clip.load returns a (model, preprocess) pair.
    return clip.load("ViT-B/32", device=device)

# --- Video Processing Functions ---
def extract_frames(video_path, fps=2):
    cap = cv2.VideoCapture(video_path)
    frames = []
    # Keep one frame out of every `frame_rate` to sample at roughly `fps` frames
    # per second; max(1, ...) avoids a modulo-by-zero for low-frame-rate sources.
    frame_rate = max(1, int(cap.get(cv2.CAP_PROP_FPS) / fps))
    count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if count % frame_rate == 0:
            frames.append(frame)
        count += 1
    cap.release()
    return frames
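
# Worked example of the sampling above: for a 30 fps source and the default
# fps=2, frame_rate == 15, so frames 0, 15, 30, ... are kept (one every 0.5 s).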

def get_clip_features(frames, model, preprocess, device):
    features = []
    for frame in frames:
        # OpenCV decodes to BGR; CLIP's preprocess expects an RGB PIL image.
        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        img_input = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            feature = model.encode_image(img_input)
        features.append(feature.cpu().numpy()[0])
    return features

def compute_distance(a, b, method):
    if method == "cosine":
        # Note: this branch returns a similarity (higher = more alike), unlike
        # l1/l2, which return distances (lower = more alike).
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    elif method == "l2":
        return np.linalg.norm(a - b)
    elif method == "l1":
        return np.sum(np.abs(a - b))
    else:
        return np.linalg.norm(a - b)
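
# Small sanity check of the metrics above (illustrative values, computed for
# orthogonal unit vectors; not from the original commit):
#   a = np.array([1.0, 0.0]); b = np.array([0.0, 1.0])
#   compute_distance(a, b, "cosine")  # -> 0.0    (similarity)
#   compute_distance(a, b, "l2")      # -> ~1.414 (distance)
#   compute_distance(a, b, "l1")      # -> 2.0    (distance)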

def find_match(clip_feats, ref_feats, threshold=0.3, similarity="l2"):
    len_clip = len(clip_feats)
    best_match = -1
    best_score = float('inf') if similarity != "cosine" else -float('inf')
    # Slide a window of the clip's length across the reference features and
    # score each alignment by the mean frame-wise distance (or similarity).
    for i in range(len(ref_feats) - len_clip + 1):
        window = ref_feats[i:i + len_clip]
        dists = [compute_distance(a, b, similarity) for a, b in zip(clip_feats, window)]
        dist = np.mean(dists)
        if (similarity != "cosine" and dist < best_score) or (similarity == "cosine" and dist > best_score):
            best_score = dist
            best_match = i
    if (similarity != "cosine" and best_score < threshold) or (similarity == "cosine" and best_score > threshold):
        return best_match, best_score
    return -1, best_score
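
# Sketch of the matching semantics: with 4 reference features and a 2-frame
# clip, windows start at indices 0, 1, and 2. For "l1"/"l2" the best window is
# the one with the smallest mean distance, and a match is reported only if that
# mean falls below `threshold`; for "cosine" both comparisons flip (largest
# mean similarity, which must exceed `threshold`). On a miss, (-1, best_score)
# is returned so callers can still see how close the best alignment was.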

# Scene Detection
def get_detector(detector_name, threshold):
    if detector_name == "ContentDetector":
        return ContentDetector(threshold=threshold)
    elif detector_name == "AdaptiveDetector":
        return AdaptiveDetector()
    elif detector_name == "ThresholdDetector":
        return ThresholdDetector(threshold=threshold)
    elif detector_name == "HashDetector":
        return HashDetector(threshold=threshold)
    elif detector_name == "HistogramDetector":
        return HistogramDetector(threshold=threshold)
    else:
        return ContentDetector(threshold=threshold)

def detect_scenes(video_path, detector_name, threshold):
    video_manager = VideoManager([video_path])
    scene_manager = SceneManager()
    detector = get_detector(detector_name, threshold)
    scene_manager.add_detector(detector)
    video_manager.set_downscale_factor()
    video_manager.start()
    scene_manager.detect_scenes(frame_source=video_manager)
    scene_list = scene_manager.get_scene_list()
    return [(scene[0].get_seconds(), scene[1].get_seconds()) for scene in scene_list]
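
# detect_scenes returns (start_seconds, end_seconds) pairs; e.g. a 30 s video
# with a single cut at 12.4 s would yield [(0.0, 12.4), (12.4, 30.0)]
# (illustrative values, not from the original commit).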

def find_scene_for_timestamp(scenes, match_time):
    for start, end in scenes:
        if start <= match_time <= end:
            return (start, end)
    return None

def extract_scene(video_path, scene_range, output_path):
    start_time, end_time = scene_range
    # Named `scene_clip` to avoid shadowing the imported `clip` module.
    scene_clip = VideoFileClip(video_path).subclipped(start_time, end_time)
    scene_clip.write_videofile(output_path, codec="libx264", audio_codec="aac")
    return output_path

# Main logic
def process_videos(clip_path, ref_path, match_threshold, scene_threshold, detector_type, similarity_type, device_type, output_path):
    device = DEVICE_OPTIONS.get(device_type, "cpu")
    model, preprocess = load_clip_model(device)

    clip_frames = extract_frames(clip_path)
    ref_frames = extract_frames(ref_path)

    clip_feats = get_clip_features(clip_frames, model, preprocess, device)
    ref_feats = get_clip_features(ref_frames, model, preprocess, device)

    match_index, score = find_match(clip_feats, ref_feats, match_threshold, similarity_type)

    if match_index == -1:
        return f"No match found (best score = {score:.4f})", None

    # Frames are sampled at 2 fps (the extract_frames default), so each window
    # index corresponds to 0.5 s of video.
    match_time = match_index * 0.5
    scenes = detect_scenes(ref_path, detector_type, scene_threshold)
    matched_scene = find_scene_for_timestamp(scenes, match_time)

    if not matched_scene:
        return "Match found, but no scene boundaries detected.", None
    output_path = os.path.join(output_path, "matched_scene.mp4")
    result_path = extract_scene(ref_path, matched_scene, output_path)

    return f"Match found at ~{match_time:.2f}s (score = {score:.4f})\nScene from {matched_scene[0]:.2f}s to {matched_scene[1]:.2f}s", result_path
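
# Hedged usage sketch for calling the pipeline directly (hypothetical file
# paths and default-like parameters, not from the original commit):
#   info, scene_path = process_videos(
#       "clip.mp4", "reference.mp4",
#       match_threshold=0.3, scene_threshold=30,
#       detector_type="ContentDetector", similarity_type="l2",
#       device_type="cpu", output_path="/tmp",
#   )
#   print(info)  # e.g. "Match found at ~12.50s (score = 0.2143)\nScene from ..."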

# Gradio Interface
with tempfile.TemporaryDirectory() as tmpdir:
    iface = gr.Interface(
        fn=process_videos,
        inputs=[
            gr.Video(label="Clip Video"),
            gr.Video(label="Reference Video"),
            gr.Slider(0.1, 100.0, value=0.3, label="Matching Threshold (l1/l2: lower = stricter; cosine: higher = stricter)"),
            gr.Slider(0.01, 100, value=30, step=1, label="Scene Detection Threshold"),
            gr.Dropdown([
                "ContentDetector", "AdaptiveDetector", "ThresholdDetector", "HistogramDetector", "HashDetector"
            ], value="ContentDetector", label="Scene Detector Type"),
            gr.Dropdown(["l2", "l1", "cosine"], value="l2", label="Similarity Metric"),
            gr.Dropdown(["cpu", "cuda", "mps"], value="cpu", label="Processing Device"),
            # Hidden field that passes the temporary output directory through to process_videos.
            gr.Text(value=tmpdir, visible=False)
        ],
        outputs=[
            gr.Text(label="Match Info"),
            gr.Video(label="Matched Scene")
        ],
        title="AI Video Clip Matcher",
        description="Upload a short video clip and a reference video. The system will try to find where the clip appears in the reference video and extract the full scene around it."
    )

    # --- Launch the App ---
    if __name__ == "__main__":
        print("Launching Gradio interface...")

        # Set the `GRADIO_SERVER_NAME` / `GRADIO_SERVER_PORT` env vars to override
        # the defaults; use `GRADIO_SERVER_NAME=0.0.0.0` for Docker. launch() blocks
        # here, which keeps the temporary directory alive for the app's lifetime.
        iface.launch()