# vlm-rlaif-demo/processing_utils.py
import os
import json
import base64
from io import BytesIO

import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

from torchvision import transforms
from torchvision.transforms import Compose, Lambda, ToTensor
from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo, CenterCropVideo

from transformers import TextStreamer, ProcessorMixin, BatchEncoding
from transformers.image_processing_utils import BatchFeature

from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample
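
# Utilities for converting a video (a video file, or a directory of pre-extracted
# frame images) into a frame tensor (typically (C, T, H, W)) for the VLM-RLAIF demo.
# VideoProcessor below supports four decode backends: pytorchvideo, decord, opencv, frames.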


def load_frames(frames_dir):
    """Return the list of frame image paths under frames_dir, sorted lexicographically by filename."""
    results = []
    frame_names = os.listdir(frames_dir)
    frame_names.sort()
    for frame_name in frame_names:
        image_path = f"{frames_dir}/{frame_name}"
        results.append(image_path)
    return results


def sample_frames(frames, num_segments):
    """Uniformly sample num_segments entries from the full list of frames."""
    duration = len(frames)
    frame_id_array = np.linspace(0, duration - 1, num_segments, dtype=int)
    frame_id_list = frame_id_array.tolist()
    sampled_frames = []
    for frame_idx in frame_id_list:
        single_frame_path = frames[frame_idx]
        sampled_frames.append(single_frame_path)
    return sampled_frames
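
# Example (illustrative only): for a directory containing 9 frames and num_segments=3,
# np.linspace(0, 8, 3, dtype=int) yields indices [0, 4, 8], i.e. the first, middle,
# and last frames are kept.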


class VideoProcessor:
    def __init__(self, image_transform):
        self.image_transform = image_transform

    def __call__(self, video_path, transform=None,
                 video_decode_backend='opencv',
                 clip_start_sec=0.0, clip_end_sec=None,
                 num_frames=50, **kwargs):
        if transform is None:
            transform = self.image_transform
        if video_decode_backend == 'pytorchvideo':
            # EncodedVideo also supports other decoders (e.g. pyav); decord is used here.
            video = EncodedVideo.from_path(video_path, decoder="decord", decode_audio=False)
            duration = video.duration
            start_sec = clip_start_sec  # secs
            end_sec = clip_end_sec if clip_end_sec is not None else duration  # secs
            video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)
            video_outputs = transform(video_data)
        elif video_decode_backend == 'decord':
            import decord
            from decord import VideoReader, cpu
            decord.bridge.set_bridge('torch')
            decord_vr = VideoReader(video_path, ctx=cpu(0))
            ori_duration = len(decord_vr)
            fps_vid = decord_vr.get_avg_fps()
            # Cap sampling to the first ~10 seconds of the clip (or the whole clip if shorter).
            valid_duration = min(int(fps_vid * 10), ori_duration)
            frame_id_list = np.linspace(0, valid_duration - 1, num_frames, dtype=int)
            video_data = decord_vr.get_batch(frame_id_list)
            video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
            video_outputs = transform(video_data)
        elif video_decode_backend == 'opencv':
            import cv2
            cv2_vr = cv2.VideoCapture(video_path)
            duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT))
            frame_id_list = np.linspace(0, duration - 1, num_frames, dtype=int)
            video_data = []
            for frame_idx in frame_id_list:
                # Seek to the target frame index (CAP_PROP_POS_FRAMES == 1 in the original call).
                cv2_vr.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                _, frame = cv2_vr.read()
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV reads BGR; convert to RGB
                video_data.append(torch.from_numpy(frame).permute(2, 0, 1))  # (H, W, C) -> (C, H, W)
            cv2_vr.release()
            video_data = torch.stack(video_data, dim=1)  # stack along time: (C, T, H, W)
            video_outputs = transform(video_data)
        elif video_decode_backend == 'frames':
            # FIXME: does not accept start/end clip timestamps; duration info would be needed for that.
            frames = load_frames(video_path)
            frames = sample_frames(frames, num_frames)
            to_tensor = ToTensor()
            # The sampled entries are file paths, so open each image before converting it to a tensor.
            video_data = torch.stack(
                [to_tensor(Image.open(p).convert("RGB")) for p in frames]
            ).permute(1, 0, 2, 3)  # (T, C, H, W) -> (C, T, H, W)
        else:
            raise NameError('video_decode_backend should be one of (pytorchvideo, decord, opencv, frames)')