Baraaqasem's picture
Upload 585 files
5d32408 verified
raw
history blame contribute delete
3.72 kB
import os
import cv2
from PIL import Image
# Lower-case file suffixes (leading dot included) treated as still images.
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)

# Lower-case file suffixes (leading dot included) treated as videos.
VID_EXTENSIONS = (
    ".mp4",
    ".avi",
    ".mov",
    ".mkv",
)
def is_video(filename):
    """
    Tell whether a file name carries one of the known video suffixes.

    Args:
        filename (str): file name or path; only its final extension is inspected

    Return:
        bool: True when the lower-cased extension is listed in VID_EXTENSIONS
    """
    # splitext yields (root, suffix); the suffix keeps its leading dot.
    _, suffix = os.path.splitext(filename)
    return suffix.lower() in VID_EXTENSIONS
def extract_frames(
    video_path,
    frame_inds=None,
    points=None,
    backend="opencv",
    return_length=False,
    num_frames=None,
):
    """
    Extract selected frames from a video as PIL images.

    Args:
        video_path (str): path to video
        frame_inds (List[int]): indices of frames to extract
        points (List[float]): values within [0, 1); multiply #frames to get frame indices
        backend (str): decoding backend, one of "av", "opencv" or "decord"
        return_length (bool): when True, also return the total frame count that was
            used to compute / clamp the indices
        num_frames (int): when given, used as the total frame count instead of the
            value reported by the backend

    Return:
        List[PIL.Image], or a (List[PIL.Image], int) tuple when return_length is True

    frame_inds and points are mutually exclusive; one of them has to be supplied
    because the result is built by iterating over the resulting indices. Indices at
    or beyond the total frame count are clamped to the last frame.
    """
    assert backend in ["av", "opencv", "decord"]
    assert (frame_inds is None) or (points is None)

    if backend == "av":
        import av

        container = av.open(video_path)
        # Close the container even if seeking / decoding raises.
        try:
            if num_frames is not None:
                total_frames = num_frames
            else:
                total_frames = container.streams.video[0].frames
            if points is not None:
                frame_inds = [int(p * total_frames) for p in points]

            frames = []
            for idx in frame_inds:
                if idx >= total_frames:
                    idx = total_frames - 1
                # Convert the frame index into a timestamp expressed in av.time_base units.
                target_timestamp = int(idx * av.time_base / container.streams.video[0].average_rate)
                container.seek(target_timestamp)
                # Take the first frame the decoder yields after the seek.
                frame = next(container.decode(video=0)).to_image()
                frames.append(frame)
        finally:
            container.close()

        if return_length:
            return frames, total_frames
        return frames

    elif backend == "decord":
        import decord
        # numpy is not imported at module level, so bring it in here like the
        # other optional backend dependencies.
        import numpy as np

        container = decord.VideoReader(video_path, num_threads=1)
        if num_frames is not None:
            total_frames = num_frames
        else:
            total_frames = len(container)
        if points is not None:
            frame_inds = [int(p * total_frames) for p in points]

        # Clamp out-of-range indices to the last frame.
        frame_inds = np.array(frame_inds).astype(np.int32)
        frame_inds[frame_inds >= total_frames] = total_frames - 1
        frames = container.get_batch(frame_inds).asnumpy()  # [N, H, W, C]
        frames = [Image.fromarray(x) for x in frames]

        if return_length:
            return frames, total_frames
        return frames

    elif backend == "opencv":
        cap = cv2.VideoCapture(video_path)
        # Release the capture handle even if something below raises.
        try:
            if num_frames is not None:
                total_frames = num_frames
            else:
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if points is not None:
                frame_inds = [int(p * total_frames) for p in points]

            frames = []
            for idx in frame_inds:
                if idx >= total_frames:
                    idx = total_frames - 1

                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)

                # HACK: sometimes OpenCV fails to read frames, return a black frame instead
                try:
                    ret, frame = cap.read()
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = Image.fromarray(frame)
                except Exception as e:
                    print(f"Error reading frame {video_path}: {e}")
                    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    frame = Image.new("RGB", (width, height), (0, 0, 0))

                # HACK: if height or width is 0, return a black frame instead
                if frame.height == 0 or frame.width == 0:
                    height = width = 256
                    frame = Image.new("RGB", (width, height), (0, 0, 0))

                frames.append(frame)
        finally:
            cap.release()

        if return_length:
            return frames, total_frames
        return frames
    else:
        raise ValueError