import random import cv2 import torch import numpy as np def sample_frames_start_end(num_frames, start, end, sample='rand', fix_start=None): acc_samples = min(num_frames, end) intervals = np.linspace(start=start, stop=end, num=acc_samples + 1).astype(int) ranges = [(interv, intervals[idx + 1] - 1) for idx, interv in enumerate(intervals[:-1])] if sample == 'rand': frame_idxs = [random.choice(range(x[0], x[1])) for x in ranges] elif fix_start is not None: frame_idxs = [x[0] + fix_start for x in ranges] elif sample == 'uniform': frame_idxs = [(x[0] + x[1]) // 2 for x in ranges] else: raise NotImplementedError return frame_idxs def read_frames_cv2_egoclip( video_path, num_frames, sample, ): cap = cv2.VideoCapture(video_path) vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) assert (cap.isOpened()) # get indexes of sampled frames start_f = 0 end_f = vlen frame_idxs = sample_frames_start_end(num_frames, start_f, end_f, sample=sample) frames = [] for index in frame_idxs: _index = index % (600 * 30) _index = min(_index, vlen) cap.set(cv2.CAP_PROP_POS_FRAMES, _index - 1) ret, frame = cap.read() if ret: frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame = torch.from_numpy(frame) # (H x W x C) to (C x H x W) frame = frame.permute(2, 0, 1) frames.append(frame) while len(frames) < num_frames: # complete the frame frames.append(frames[-1]) frames = torch.stack(frames).float() / 255 cap.release() return frames