import decord
# Ask decord to hand frames back as torch tensors instead of its own NDArray.
decord.bridge.set_bridge('torch')

from torch.utils.data import Dataset
from einops import rearrange
import os
from PIL import Image
import numpy as np


class TuneAVideoDataset(Dataset):
    """Single-clip dataset that yields one sampled frame stack plus prompt ids.

    ``video_path`` may be either a video file readable by decord or a
    directory of numerically named ``.jpg`` frames (``1.jpg``, ``0002.jpg``,
    ...). Directory frames are decoded and resized once, in ``__init__``;
    video files are decoded on every ``__getitem__`` call.

    ``prompt_ids`` and ``uncond_prompt_ids`` start as ``None``; this class
    never fills them in, so the caller is expected to assign them before use.
    """

    def __init__(
            self,
            video_path: str,
            prompt: str,
            width: int = 512,
            height: int = 512,
            n_sample_frames: int = 8,
            sample_start_idx: int = 0,
            sample_frame_rate: int = 1,
    ) -> None:
        """Store the sampling configuration and preload frames if needed.

        Args:
            video_path: Video file, or directory of numerically named
                ``.jpg`` frames.
            prompt: Text prompt associated with the clip.
            width: Target frame width in pixels.
            height: Target frame height in pixels.
            n_sample_frames: Maximum number of frames returned per item.
            sample_start_idx: Index of the first sampled frame (video files
                only).
            sample_frame_rate: Stride between sampled frames (video files
                only).

        Raises:
            ValueError: If ``video_path`` is a directory that contains no
                ``.jpg`` frames, or a ``.jpg`` file name is not an integer.
        """
        self.video_path = video_path
        self.prompt = prompt
        self.prompt_ids = None
        self.uncond_prompt_ids = None

        self.width = width
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.sample_start_idx = sample_start_idx
        self.sample_frame_rate = sample_frame_rate

        # Decide the source kind from the filesystem rather than from an
        # 'mp4' substring test: a frame directory whose path happens to
        # contain "mp4" is still a directory, and anything that is not a
        # directory is handed to decord regardless of its extension.
        self._from_image_dir = os.path.isdir(self.video_path)
        if self._from_image_dir:
            self.images = self._load_image_frames()

    def _load_image_frames(self) -> np.ndarray:
        """Load every ``.jpg`` frame of the directory as one stacked array."""
        # Filter before sorting so unrelated files (.DS_Store, captions,
        # thumbnails, ...) never reach the integer sort key.
        frame_names = [
            name for name in os.listdir(self.video_path)
            if name.endswith('.jpg')
        ]
        if not frame_names:
            raise ValueError(
                f"No .jpg frames found in directory: {self.video_path!r}"
            )
        # Numeric order, so that "10.jpg" comes after "9.jpg".
        frame_names.sort(key=lambda name: int(os.path.splitext(name)[0]))

        frames = []
        for name in frame_names:
            # The context manager releases the file handle once the pixel
            # data has been converted and copied into the array.
            with Image.open(os.path.join(self.video_path, name)) as img:
                resized = img.convert('RGB').resize((self.width, self.height))
                frames.append(np.asarray(resized))
        return np.stack(frames)

    def __len__(self) -> int:
        """Return 1: the dataset always holds exactly one clip."""
        return 1

    def __getitem__(self, index):
        """Return the sampled clip and the prompt ids.

        ``index`` is ignored because there is only one clip.

        Returns:
            dict with
              * ``"pixel_values"``: frames rearranged from ``f h w c`` to
                ``f c h w`` and rescaled with ``x / 127.5 - 1.0``
                (maps 0..255 onto -1..1).
              * ``"prompt_ids"``: the current value of ``self.prompt_ids``.
        """
        # load and sample video frames
        if self._from_image_dir:
            # NOTE(review): sample_start_idx and sample_frame_rate are not
            # applied to image directories; only the first n_sample_frames
            # frames are used. Kept as-is so existing results do not change.
            video = self.images[:self.n_sample_frames]
        else:
            vr = decord.VideoReader(self.video_path, width=self.width, height=self.height)
            sample_index = list(
                range(self.sample_start_idx, len(vr), self.sample_frame_rate)
            )[:self.n_sample_frames]
            video = vr.get_batch(sample_index)
        video = rearrange(video, "f h w c -> f c h w")

        example = {
            "pixel_values": (video / 127.5 - 1.0),
            "prompt_ids": self.prompt_ids,
        }

        return example