import os

import decord
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset

# Have decord return torch tensors instead of numpy arrays.
decord.bridge.set_bridge('torch')

class TuneAVideoDataset(Dataset):
    """Single-video dataset: yields one clip of n_sample_frames frames plus its prompt token ids."""

    def __init__(
            self,
            video_path: str,
            prompt: str,
            width: int = 512,
            height: int = 512,
            n_sample_frames: int = 8,
            sample_start_idx: int = 0,
            sample_frame_rate: int = 1,
    ):
        self.video_path = video_path
        self.prompt = prompt
        # Token ids are filled in externally (e.g. by the training script, after
        # tokenizing `prompt` and the unconditional/empty prompt).
        self.prompt_ids = None
        self.uncond_prompt_ids = None

        self.width = width
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.sample_start_idx = sample_start_idx
        self.sample_frame_rate = sample_frame_rate

        if 'mp4' not in self.video_path:
            # Treat video_path as a directory of frames named by index, e.g. 0.jpg, 1.jpg, ...
            frame_files = sorted(
                (f for f in os.listdir(self.video_path) if f.endswith('jpg')),
                key=lambda x: int(os.path.splitext(x)[0]),
            )
            self.images = []
            for file in frame_files:
                image = Image.open(os.path.join(self.video_path, file)).convert('RGB')
                self.images.append(np.asarray(image.resize((self.width, self.height))))
            self.images = np.stack(self.images)

    def __len__(self):
        # The dataset wraps a single video, so there is exactly one sample.
        return 1

    def __getitem__(self, index):
        # Load and sample video frames.
        if 'mp4' in self.video_path:
            # Decode the video on the fly, resized to (width, height), keeping every
            # sample_frame_rate-th frame starting at sample_start_idx.
            vr = decord.VideoReader(self.video_path, width=self.width, height=self.height)
            sample_index = list(range(self.sample_start_idx, len(vr), self.sample_frame_rate))[:self.n_sample_frames]
            video = vr.get_batch(sample_index)  # torch tensor, via the decord bridge
        else:
            # Use the first n_sample_frames of the pre-loaded image frames; convert to
            # a torch tensor so both branches return the same type.
            video = torch.from_numpy(self.images[:self.n_sample_frames])
        video = rearrange(video, "f h w c -> f c h w")

        example = {
            # Scale uint8 pixels from [0, 255] to [-1, 1].
            "pixel_values": (video / 127.5 - 1.0),
            "prompt_ids": self.prompt_ids,
        }

        return example
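

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the dataset class): shows how
# a training script might fill `prompt_ids` and wrap the dataset in a
# DataLoader. The video path and the tokenizer checkpoint below are
# assumptions; swap in the tokenizer of whichever text encoder the pipeline
# actually uses.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer

    dataset = TuneAVideoDataset(
        video_path="data/car-turn.mp4",   # hypothetical example path
        prompt="a jeep car is moving on the road",
        n_sample_frames=8,
    )

    # Tokenize the prompt to a fixed-length id tensor before collation,
    # since __getitem__ returns prompt_ids as-is.
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    dataset.prompt_ids = tokenizer(
        dataset.prompt,
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).input_ids[0]

    loader = DataLoader(dataset, batch_size=1)
    batch = next(iter(loader))
    print(batch["pixel_values"].shape)  # e.g. torch.Size([1, 8, 3, 512, 512])
    print(batch["prompt_ids"].shape)    # e.g. torch.Size([1, 77])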