Update app.py
Browse files
app.py
CHANGED
@@ -6,9 +6,16 @@ import cv2 # OpenCV for video processing
|
|
6 |
# Model ID for video classification (UCF101 subset)
|
7 |
model_id = "MCG-NJU/videomae-base"
|
8 |
|
|
|
|
|
|
|
|
|
9 |
def analyze_video(video):
|
10 |
# Extract key frames from the video using OpenCV
|
11 |
-
frames = extract_key_frames(video)
|
|
|
|
|
|
|
12 |
|
13 |
# Load model and feature extractor manually
|
14 |
model = VideoMAEForVideoClassification.from_pretrained(model_id)
|
@@ -35,16 +42,21 @@ def analyze_video(video):
|
|
35 |
|
36 |
return final_result
|
37 |
|
38 |
-
def extract_key_frames(video):
|
39 |
cap = cv2.VideoCapture(video)
|
40 |
frames = []
|
41 |
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
42 |
-
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
43 |
|
44 |
-
for
|
|
|
|
|
|
|
|
|
45 |
ret, frame = cap.read()
|
46 |
-
if ret
|
47 |
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) # Convert to RGB
|
|
|
|
|
48 |
|
49 |
cap.release()
|
50 |
return frames
|
|
|
6 |
# Model ID for video classification (UCF101 subset)
|
7 |
model_id = "MCG-NJU/videomae-base"
|
8 |
|
9 |
+
# Parameters for frame extraction
|
10 |
+
TARGET_FRAME_COUNT = 16
|
11 |
+
FRAME_SIZE = (224, 224) # Expected frame size for the model
|
12 |
+
|
13 |
def analyze_video(video):
|
14 |
# Extract key frames from the video using OpenCV
|
15 |
+
frames = extract_key_frames(video, TARGET_FRAME_COUNT)
|
16 |
+
|
17 |
+
# Resize frames to the expected size
|
18 |
+
frames = [cv2.resize(frame, FRAME_SIZE) for frame in frames]
|
19 |
|
20 |
# Load model and feature extractor manually
|
21 |
model = VideoMAEForVideoClassification.from_pretrained(model_id)
|
|
|
42 |
|
43 |
return final_result
|
44 |
|
45 |
+
def extract_key_frames(video, target_frame_count):
|
46 |
cap = cv2.VideoCapture(video)
|
47 |
frames = []
|
48 |
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
|
49 |
|
50 |
+
# Calculate interval for frame extraction
|
51 |
+
interval = max(1, frame_count // target_frame_count)
|
52 |
+
|
53 |
+
for i in range(0, frame_count, interval):
|
54 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, i)
|
55 |
ret, frame = cap.read()
|
56 |
+
if ret:
|
57 |
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) # Convert to RGB
|
58 |
+
if len(frames) >= target_frame_count:
|
59 |
+
break
|
60 |
|
61 |
cap.release()
|
62 |
return frames
|