from utils.dataset_utils import *


class VideoFolderDataset(Dataset):
    """Dataset serving one clip per ``*.mp4`` file found directly in a folder.

    Every item is paired with the same ``fallback_prompt``; no per-video
    caption is read by this class.
    """

    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 16,
        fps: int = 8,
        path: str = "./data",
        fallback_prompt: str = "",
        use_bucketing: bool = False,
        **kwargs
    ):
        """Collect the video files and store the sampling configuration.

        Args:
            tokenizer: Callable tokenizer exposing ``model_max_length``; used
                by :meth:`get_prompt_ids`.
            width: Target frame width handed to the video loader / bucketing.
            height: Target frame height handed to the video loader / bucketing.
            n_sample_frames: Maximum number of frames sampled per clip.
            fps: Desired sampling rate; the frame stride is derived from the
                ratio between the video's native fps and this value.
            path: Folder that is searched (non-recursively) for ``*.mp4``.
            fallback_prompt: Text prompt attached to every item.
            use_bucketing: Forwarded to ``process_video``.
            **kwargs: Accepted and ignored, so a shared config can be splatted
                into the constructor.
        """
        self.tokenizer = tokenizer
        self.use_bucketing = use_bucketing

        self.fallback_prompt = fallback_prompt

        # glob() yields matches in arbitrary, filesystem-dependent order.
        # Sorting keeps the index -> file mapping reproducible across runs
        # and machines.
        self.video_files = sorted(glob(f"{path}/*.mp4"))

        self.width = width
        self.height = height

        self.n_sample_frames = n_sample_frames
        self.fps = fps

    def get_frame_buckets(self, vr):
        """Build a resize transform sized by ``sensible_buckets``.

        Args:
            vr: Indexable video reader whose first frame has a 3-element
                shape unpacked as (height, width, channels). Presumably a
                decord ``VideoReader`` -- confirm against ``process_video``.

        Returns:
            A ``Resize`` transform targeting ``(height, width)`` as computed
            by ``sensible_buckets``.
        """
        h, w, _ = vr[0].shape
        width, height = sensible_buckets(self.width, self.height, w, h)
        return T.transforms.Resize((height, width), antialias=True)

    def get_frame_batch(self, vr, resize=None):
        """Sample a random run of evenly strided frames from ``vr``.

        The stride is ``round(native_fps / self.fps)``, clamped to
        ``[1, len(vr)]``. If the video is too short to provide
        ``n_sample_frames`` frames at that stride, fewer frames are returned.

        Args:
            vr: Video reader supporting ``len()``, ``get_avg_fps()`` and
                ``get_batch(indices)``.
            resize: Optional callable applied to the frames after they have
                been rearranged to ``f c h w``.

        Returns:
            Tuple ``(video, vr)``: the sampled frames and the reader itself.

        Raises:
            ValueError: If ``vr`` contains no frames.
        """
        total_frames = len(vr)
        if total_frames == 0:
            # Without this guard the stride clamps to 0 and the floor
            # division below fails with an opaque ZeroDivisionError.
            raise ValueError("Cannot sample frames from a video with no frames.")

        native_fps = vr.get_avg_fps()
        every_nth_frame = max(1, round(native_fps / self.fps))
        every_nth_frame = min(total_frames, every_nth_frame)

        # Number of frames reachable when stepping by the stride.
        effective_length = total_frames // every_nth_frame
        n_sample_frames = min(self.n_sample_frames, effective_length)

        # randint is inclusive on both ends, so the last sampled index is at
        # most every_nth_frame * (effective_length - 1) < total_frames.
        effective_idx = random.randint(0, effective_length - n_sample_frames)
        idxs = every_nth_frame * np.arange(
            effective_idx, effective_idx + n_sample_frames
        )

        video = vr.get_batch(idxs)
        video = rearrange(video, "f h w c -> f c h w")

        if resize is not None:
            video = resize(video)
        return video, vr

    def process_video_wrapper(self, vid_path):
        """Load ``vid_path`` via ``process_video`` using this dataset's settings.

        Returns:
            The ``(video, vr)`` pair produced by ``process_video``.
        """
        video, vr = process_video(
            vid_path,
            self.use_bucketing,
            self.width,
            self.height,
            self.get_frame_buckets,
            self.get_frame_batch,
        )
        return video, vr

    def get_prompt_ids(self, prompt):
        """Tokenize ``prompt``, padded/truncated to the tokenizer's max length.

        Returns:
            The ``input_ids`` of the tokenizer output, requested as PyTorch
            tensors (``return_tensors="pt"``).
        """
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    @staticmethod
    def __getname__() -> str:
        """Return the identifier of this dataset type."""
        return 'folder'

    def __len__(self) -> int:
        """Return the number of video files discovered in the folder."""
        return len(self.video_files)

    def __getitem__(self, index):
        """Return the training example for the video at ``index``.

        Returns:
            Dict with keys ``pixel_values``, ``prompt_ids``, ``text_prompt``
            and ``dataset``.
        """
        video, _ = self.process_video_wrapper(self.video_files[index])

        prompt = self.fallback_prompt
        prompt_ids = self.get_prompt_ids(prompt)

        return {
            # NOTE(review): video[0] presumably selects the frame tensor from
            # what process_video returns; the rescale maps 0..255 onto -1..1,
            # assuming 8-bit pixel values -- confirm against process_video.
            "pixel_values": (video[0] / 127.5 - 1.0),
            "prompt_ids": prompt_ids[0],
            "text_prompt": prompt,
            'dataset': self.__getname__(),
        }