bluestarburst's picture
Upload folder using huggingface_hub
09bf9a3
raw
history blame
1.53 kB
import decord
decord.bridge.set_bridge('torch')
from typing import Callable, List, Optional, Union
from torch.utils.data import Dataset
from einops import rearrange
class MultiTuneAVideoDataset(Dataset):
    """Dataset of frame clips sampled from one or more videos, paired with prompt ids.

    Each item decodes one video file and keeps every ``sample_frame_rate``-th
    frame starting at ``sample_start_idx`` (at most ``n_sample_frames`` frames).
    It returns those frames together with the entry of ``self.prompt_ids`` at
    the same index.

    ``self.prompt_ids`` starts empty and is never filled by this class. The
    caller must assign one entry per video before reading items. These are
    presumably tokenized prompts; confirm against the training script.
    """

    def __init__(
        self,
        video_path: Union[str, List[str]],
        prompt: Union[str, List[str]],
        width: int = 512,
        height: int = 512,
        n_sample_frames: int = 8,
        sample_start_idx: int = 0,
        sample_frame_rate: int = 1,
    ):
        """Store the video list, prompts and frame-sampling configuration.

        Args:
            video_path: One video path, or a list of paths. A single string is
                wrapped in a one-element list.
            prompt: One text prompt, or a list of prompts. A single string is
                repeated once per video; a list is stored unchanged.
            width: Frame width requested from ``decord.VideoReader``.
            height: Frame height requested from ``decord.VideoReader``.
            n_sample_frames: Maximum number of frames returned per item.
            sample_start_idx: Index of the first sampled frame.
            sample_frame_rate: Stride between consecutive sampled frames.
        """
        self.video_path = [video_path] if isinstance(video_path, str) else video_path
        self.prompt = [prompt] * len(self.video_path) if isinstance(prompt, str) else prompt
        # Placeholder: must be assigned externally (one entry per video);
        # __getitem__ indexes it with the same index as video_path.
        self.prompt_ids = []
        self.width = width
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.sample_start_idx = sample_start_idx
        self.sample_frame_rate = sample_frame_rate

    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self.video_path)

    def __getitem__(self, index):
        """Load, sample and normalize the frames of video ``index``.

        Returns:
            A dict with ``"pixel_values"`` (frames rearranged to
            ``f c h w`` and scaled from [0, 255] to [-1, 1]) and
            ``"prompt_ids"`` (``self.prompt_ids[index]``).

        Raises:
            IndexError: If ``index`` is out of range for ``video_path``.
            RuntimeError: If ``prompt_ids`` has no entry for ``index``.
            ValueError: If the sampling parameters select no frames.
        """
        # Index the path first so an out-of-range index still raises
        # IndexError, which normally ends sequence-protocol iteration.
        path = self.video_path[index]
        try:
            prompt_ids = self.prompt_ids[index]
        except IndexError:
            # An IndexError here would make `for item in dataset` silently stop
            # at the first item; fail loudly, and before decoding the video.
            raise RuntimeError(
                f"prompt_ids has no entry for index {index}; assign "
                "dataset.prompt_ids (one entry per video) before reading items"
            ) from None

        # load and sample video frames
        vr = decord.VideoReader(path, width=self.width, height=self.height)
        # Slice the range itself instead of materializing every frame index.
        frame_range = range(self.sample_start_idx, len(vr), self.sample_frame_rate)
        sample_index = list(frame_range[: self.n_sample_frames])
        if not sample_index:
            raise ValueError(
                f"no frames sampled from {path!r}: sample_start_idx="
                f"{self.sample_start_idx}, sample_frame_rate="
                f"{self.sample_frame_rate}, video length={len(vr)}"
            )
        video = vr.get_batch(sample_index)
        video = rearrange(video, "f h w c -> f c h w")
        return {
            "pixel_values": video / 127.5 - 1.0,
            "prompt_ids": prompt_ids,
        }