File size: 1,525 Bytes
09bf9a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import decord
decord.bridge.set_bridge('torch')

from typing import Callable, List, Optional, Union
from torch.utils.data import Dataset
from einops import rearrange


class MultiTuneAVideoDataset(Dataset):
    def __init__(
            self,
            video_path: Union[str, list[str]],
            prompt: Union[str, list[str]],
            width: int = 512,
            height: int = 512,
            n_sample_frames: int = 8,
            sample_start_idx: int = 0,
            sample_frame_rate: int = 1,
    ):
        self.video_path = [video_path] if isinstance(video_path, str) else video_path
        self.prompt = [prompt] * len(self.video_path) if isinstance(prompt, str) else prompt
        self.prompt_ids = []

        self.width = width
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.sample_start_idx = sample_start_idx
        self.sample_frame_rate = sample_frame_rate

    def __len__(self):
        return len(self.video_path)

    def __getitem__(self, index):
        # load and sample video frames
        vr = decord.VideoReader(self.video_path[index], width=self.width, height=self.height)
        sample_index = list(range(self.sample_start_idx, len(vr), self.sample_frame_rate))[:self.n_sample_frames]
        video = vr.get_batch(sample_index)
        video = rearrange(video, "f h w c -> f c h w")

        example = {
            "pixel_values": (video / 127.5 - 1.0),
            "prompt_ids": self.prompt_ids[index]
        }

        return example