import cv2
import numpy as np
import torch
from torchvision import transforms


def extract_frames(video_path, num_frames=71, resize=(224, 224)):
    """Sample frames from a video and return them as a float32 tensor.

    Frames are taken at positions ``0, interval, 2 * interval, ...`` where
    ``interval = max(total_frames // num_frames, 1)``. Each frame is resized,
    converted from BGR to RGB and divided by 255.0.

    Args:
        video_path: Source handed to ``cv2.VideoCapture``.
        num_frames: Maximum number of frames to sample; must be >= 1.
        resize: Target size passed to ``cv2.resize`` as ``dsize``, which
            OpenCV interprets as ``(width, height)``.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, T, C, H, W)``. ``T`` equals
        ``num_frames`` unless a read fails first (for example when the video
        has fewer frames than requested), in which case ``T`` is the number
        of frames read before the failure.

    Raises:
        ValueError: If ``num_frames`` is less than 1, the video cannot be
            opened, or no frame could be read.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened():
            raise ValueError(f"could not open video: {video_path!r}")

        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Clamp to 1 so short videos still advance one frame per sample.
        interval = max(total // num_frames, 1)

        for i in range(num_frames):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i * interval)
            ret, frame = cap.read()
            if not ret:
                # Seeking past the end (or a decode failure) ends sampling;
                # the frames collected so far are still returned.
                break
            frame = cv2.resize(frame, resize)
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture handle even if resize/cvtColor raised.
        cap.release()

    if not frames:
        raise ValueError(f"no frames could be read from video: {video_path!r}")

    # Scale once over the stacked array instead of once per frame.
    stacked = np.stack(frames, axis=0) / 255.0  # (T, H, W, C)
    clip = torch.tensor(stacked, dtype=torch.float32).permute(0, 3, 1, 2)  # (T, C, H, W)
    return clip.unsqueeze(0)  # (1, T, C, H, W)