Spaces:
Running
on
Zero
Running
on
Zero
from utils.dataset_utils import * | |
class VideoFolderDataset(Dataset):
    """Dataset over the ``*.mp4`` files found directly inside one folder.

    Each item is a clip of up to ``n_sample_frames`` frames, strided to
    roughly ``fps`` frames per second from a random start position, paired
    with ``fallback_prompt``. No per-video captions are read, so every item
    uses the same prompt.

    Args:
        tokenizer: Callable used to encode the prompt. It must expose
            ``model_max_length`` and return an object with ``input_ids``
            (presumably a Hugging Face tokenizer — confirm against callers).
            ``__getitem__`` needs it despite the ``None`` default.
        width: Target frame width, forwarded to ``process_video`` and
            ``sensible_buckets``.
        height: Target frame height, forwarded to ``process_video`` and
            ``sensible_buckets``.
        n_sample_frames: Maximum number of frames per clip.
        fps: Target sampling rate used to choose the frame stride.
        path: Folder searched (non-recursively) for ``*.mp4`` files.
        fallback_prompt: Prompt text returned for every item.
        use_bucketing: Forwarded to ``process_video``. When set, frames are
            presumably resized via ``get_frame_buckets`` — confirm in
            ``utils.dataset_utils``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 16,
        fps: int = 8,
        path: str = "./data",
        fallback_prompt: str = "",
        use_bucketing: bool = False,
        **kwargs
    ):
        self.tokenizer = tokenizer
        self.use_bucketing = use_bucketing
        self.fallback_prompt = fallback_prompt
        # glob() returns matches in filesystem-dependent order; sort so a given
        # index maps to the same file on every machine and every run.
        self.video_files = sorted(glob(f"{path}/*.mp4"))
        self.width = width
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.fps = fps

    def get_frame_buckets(self, vr):
        """Build a resize transform sized for the frames of ``vr``.

        The source size comes from the first frame, whose shape is unpacked as
        (height, width, channels). That size and the configured width/height
        are passed to ``sensible_buckets``, which picks the output size.

        Returns:
            A ``T.transforms.Resize`` to ``(height, width)`` with antialiasing.
        """
        h, w, _ = vr[0].shape
        width, height = sensible_buckets(self.width, self.height, w, h)
        resize = T.transforms.Resize((height, width), antialias=True)
        return resize

    def get_frame_batch(self, vr, resize=None):
        """Sample one clip of evenly strided frames from ``vr``.

        The stride is ``round(native_fps / self.fps)``, clamped to
        ``[1, len(vr)]``. If fewer than ``n_sample_frames`` frames are
        available at that stride, the clip is shortened to what is available.
        The start position is drawn with ``random.randint``.

        Args:
            vr: Video reader exposing ``get_avg_fps()``, ``get_batch()``,
                ``len()`` and indexing (presumably a ``decord.VideoReader`` —
                confirm).
            resize: Optional transform applied to the stacked frames.

        Returns:
            ``(video, vr)``, where ``video`` is rearranged from
            ``"f h w c"`` to ``"f c h w"``.
        """
        native_fps = vr.get_avg_fps()
        # Keep every Nth source frame to approximate the target fps.
        every_nth_frame = max(1, round(native_fps / self.fps))
        # Cap the stride so at least one frame fits in the video.
        # NOTE(review): a zero-length reader makes the stride 0 and the
        # division below raise ZeroDivisionError.
        every_nth_frame = min(len(vr), every_nth_frame)
        # Number of frames available at this stride.
        effective_length = len(vr) // every_nth_frame
        # Shorten the clip when the video is too short for a full one.
        n_sample_frames = min(self.n_sample_frames, effective_length)
        # Random start, in strided units, such that the whole clip fits.
        effective_idx = random.randint(0, effective_length - n_sample_frames)
        idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)
        video = vr.get_batch(idxs)
        video = rearrange(video, "f h w c -> f c h w")
        if resize is not None:
            video = resize(video)
        return video, vr

    def process_video_wrapper(self, vid_path):
        """Load ``vid_path`` through ``process_video`` using this dataset's settings.

        Passes the bucketing flag, target size, and this instance's
        ``get_frame_buckets`` / ``get_frame_batch`` methods as callbacks.
        Returns whatever pair ``process_video`` returns (unpacked here as
        ``video, vr``).
        """
        video, vr = process_video(
            vid_path,
            self.use_bucketing,
            self.width,
            self.height,
            self.get_frame_buckets,
            self.get_frame_batch
        )
        return video, vr

    def get_prompt_ids(self, prompt):
        """Tokenize ``prompt``, padded and truncated to ``tokenizer.model_max_length``.

        Returns:
            The tokenizer output's ``input_ids``, requested as PyTorch tensors
            (``return_tensors="pt"``).
        """
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    # Static because it takes no ``self``. ``__getitem__`` calls it as
    # ``self.__getname__()``, which raised TypeError without the decorator.
    @staticmethod
    def __getname__():
        """Return the short identifier of this dataset type."""
        return 'folder'

    def __len__(self):
        """Return the number of ``*.mp4`` files found at construction time."""
        return len(self.video_files)

    def __getitem__(self, index):
        """Load the video at ``index`` and pair it with the fallback prompt.

        Returns:
            A dict with:

            - ``pixel_values``: frames computed as ``/ 127.5 - 1.0``, which
              maps 0..255 to -1..1, assuming 8-bit pixel values.
            - ``prompt_ids``: the first row of the tokenized prompt.
            - ``text_prompt``: the prompt text.
            - ``dataset``: ``'folder'``.
        """
        video, _ = self.process_video_wrapper(self.video_files[index])
        prompt = self.fallback_prompt
        prompt_ids = self.get_prompt_ids(prompt)
        # ``video`` is presumably the (frames, reader) tuple produced by
        # get_frame_batch, so [0] selects the frames. Confirm in
        # utils.dataset_utils.process_video.
        return {"pixel_values": (video[0] / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt, 'dataset': self.__getname__()}