|
import decord |
|
decord.bridge.set_bridge('torch') |
|
|
|
from typing import Callable, List, Optional, Union |
|
from torch.utils.data import Dataset |
|
from einops import rearrange |
|
|
|
|
|
class MultiTuneAVideoDataset(Dataset):
    """Dataset of video clips, each paired with a text prompt.

    One item corresponds to one video file. A fixed number of frames is
    sampled from the video at a regular stride, moved to channel-first
    layout and rescaled with ``x / 127.5 - 1.0`` (which maps 0..255 to
    ``[-1, 1]``).

    ``prompt_ids`` starts out empty. This class never fills it, so the
    caller must assign one entry per video (presumably the tokenized form
    of ``prompt`` -- confirm against the training script) before indexing
    the dataset.

    Args:
        video_path: A single video path or a list of video paths.
        prompt: A single prompt, reused for every video, or a list holding
            exactly one prompt per video.
        width: Frame width requested from the video reader.
        height: Frame height requested from the video reader.
        n_sample_frames: Maximum number of frames returned per video.
        sample_start_idx: Index of the first sampled frame.
        sample_frame_rate: Stride, in frames, between sampled frames.

    Raises:
        ValueError: If ``prompt`` is a list whose length differs from the
            number of videos.
    """

    def __init__(
        self,
        video_path: Union[str, List[str]],
        prompt: Union[str, List[str]],
        width: int = 512,
        height: int = 512,
        n_sample_frames: int = 8,
        sample_start_idx: int = 0,
        sample_frame_rate: int = 1,
    ) -> None:
        super().__init__()

        self.video_path = [video_path] if isinstance(video_path, str) else video_path
        # A single prompt string is shared by every video.
        self.prompt = [prompt] * len(self.video_path) if isinstance(prompt, str) else prompt
        if len(self.prompt) != len(self.video_path):
            raise ValueError(
                f"got {len(self.prompt)} prompt(s) for {len(self.video_path)} video(s); "
                "pass a single prompt or exactly one prompt per video"
            )
        # Filled in by the caller after construction; see the class docstring.
        self.prompt_ids = []

        self.width = width
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.sample_start_idx = sample_start_idx
        self.sample_frame_rate = sample_frame_rate

    def __len__(self) -> int:
        """Return the number of videos in the dataset."""
        return len(self.video_path)

    def _sample_indices(self, num_frames: int) -> List[int]:
        """Return the frame indices to read from a video of ``num_frames`` frames."""
        strided = range(self.sample_start_idx, num_frames, self.sample_frame_rate)
        # Slice the lazy range before materializing it so that a long video
        # does not allocate one index per frame just to keep the first few.
        return list(strided[: self.n_sample_frames])

    def __getitem__(self, index: int) -> dict:
        """Load the sampled frames and prompt ids for the video at ``index``.

        Returns:
            A dict with ``"pixel_values"`` (frames in ``f c h w`` layout,
            rescaled with ``x / 127.5 - 1.0``) and ``"prompt_ids"`` (the
            entry of ``self.prompt_ids`` at ``index``).

        Raises:
            IndexError: If ``index`` is out of range for the videos, or if
                ``prompt_ids`` has no entry for it.
        """
        video_file = self.video_path[index]

        # Resolve the prompt ids before decoding so that a dataset whose
        # prompt_ids were never populated fails fast with a clear message
        # instead of after the (expensive) video read.
        try:
            prompt_ids = self.prompt_ids[index]
        except IndexError:
            raise IndexError(
                f"no prompt_ids entry for index {index}; assign one entry per "
                "video to `prompt_ids` before indexing the dataset"
            ) from None

        reader = decord.VideoReader(video_file, width=self.width, height=self.height)
        frames = reader.get_batch(self._sample_indices(len(reader)))
        frames = rearrange(frames, "f h w c -> f c h w")

        return {
            "pixel_values": frames / 127.5 - 1.0,
            "prompt_ids": prompt_ids,
        }
|
|