MNGames committed on
Commit
e7b04ff
1 Parent(s): d2d1207

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -5
app.py CHANGED
@@ -6,9 +6,16 @@ import cv2 # OpenCV for video processing
6
  # Model ID for video classification (UCF101 subset)
7
  model_id = "MCG-NJU/videomae-base"
8
 
 
 
 
 
9
  def analyze_video(video):
10
  # Extract key frames from the video using OpenCV
11
- frames = extract_key_frames(video)
 
 
 
12
 
13
  # Load model and feature extractor manually
14
  model = VideoMAEForVideoClassification.from_pretrained(model_id)
@@ -35,16 +42,21 @@ def analyze_video(video):
35
 
36
  return final_result
37
 
38
def extract_key_frames(video):
    """Extract roughly two frames per second from a video file.

    Parameters
    ----------
    video : str
        Path to a video file readable by OpenCV's ``VideoCapture``.

    Returns
    -------
    list
        Frames converted from OpenCV's BGR ordering to RGB arrays,
        sampled about every half second of footage.
    """
    cap = cv2.VideoCapture(video)
    frames = []
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        # Guard against fps < 2, or a failed FPS query returning 0,
        # which would make `fps // 2` zero and crash the modulo below.
        step = max(1, fps // 2)
        for i in range(frame_count):
            ret, frame = cap.read()
            if ret and i % step == 0:  # keep one frame every ~half second
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # BGR -> RGB
    finally:
        # Release the capture handle even if decoding raises mid-loop.
        cap.release()
    return frames
 
6
  # Model ID for video classification (UCF101 subset)
7
  model_id = "MCG-NJU/videomae-base"
8
 
9
+ # Parameters for frame extraction
10
+ TARGET_FRAME_COUNT = 16
11
+ FRAME_SIZE = (224, 224) # Expected frame size for the model
12
+
13
  def analyze_video(video):
14
  # Extract key frames from the video using OpenCV
15
+ frames = extract_key_frames(video, TARGET_FRAME_COUNT)
16
+
17
+ # Resize frames to the expected size
18
+ frames = [cv2.resize(frame, FRAME_SIZE) for frame in frames]
19
 
20
  # Load model and feature extractor manually
21
  model = VideoMAEForVideoClassification.from_pretrained(model_id)
 
42
 
43
  return final_result
44
 
45
def extract_key_frames(video, target_frame_count):
    """Sample up to ``target_frame_count`` evenly spaced frames from a video.

    Parameters
    ----------
    video : str
        Path to a video file readable by OpenCV's ``VideoCapture``.
    target_frame_count : int
        Maximum number of frames to return.

    Returns
    -------
    list
        Frames converted from OpenCV's BGR ordering to RGB arrays.
        May contain fewer than ``target_frame_count`` entries if the
        clip is short or some seeks/reads fail.
    """
    cap = cv2.VideoCapture(video)
    frames = []
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Even spacing across the clip; stride of at least 1 so very
        # short clips (frame_count < target_frame_count) still iterate.
        interval = max(1, frame_count // target_frame_count)
        for i in range(0, frame_count, interval):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)  # seek to the i-th frame
            ret, frame = cap.read()
            if ret:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # BGR -> RGB
            if len(frames) >= target_frame_count:
                break
    finally:
        # Guarantee the capture handle is released even if a seek,
        # read, or color conversion raises mid-loop.
        cap.release()
    return frames