# Source: MotionInversion/dataset/video_folder_dataset.py
# (ziyangmai, "page demo", 113884e)
from utils.dataset_utils import *
class VideoFolderDataset(Dataset):
    """Dataset of ``*.mp4`` clips found directly inside a single folder.

    Every clip is paired with the same ``fallback_prompt``; no per-video
    captions are read.
    """

    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 16,
        fps: int = 8,
        path: str = "./data",
        fallback_prompt: str = "",
        use_bucketing: bool = False,
        **kwargs
    ):
        """Collect the videos under *path* and store the sampling settings.

        Args:
            tokenizer: Tokenizer exposing ``model_max_length``. It is only used
                by ``get_prompt_ids`` (and therefore ``__getitem__``), so it
                must be set before items are fetched.
            width: Target frame width forwarded to ``process_video``.
            height: Target frame height forwarded to ``process_video``.
            n_sample_frames: Maximum number of frames sampled per clip.
            fps: Target sampling rate; frames are strided to approximate it.
            path: Folder searched (non-recursively) for ``*.mp4`` files.
            fallback_prompt: Text prompt used for every clip.
            use_bucketing: Forwarded to ``process_video`` together with
                ``get_frame_buckets`` (presumably enables bucketed resizing --
                confirm in ``utils.dataset_utils``).
            **kwargs: Accepted and ignored (presumably so a shared config can
                be splatted in).
        """
        self.tokenizer = tokenizer
        self.use_bucketing = use_bucketing
        self.fallback_prompt = fallback_prompt
        # glob() order depends on the filesystem; sort so that the mapping
        # index -> file is reproducible across machines and runs.
        self.video_files = sorted(glob(f"{path}/*.mp4"))
        self.width = width
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.fps = fps

    def get_frame_buckets(self, vr):
        """Return a ``Resize`` transform to the bucketed size for this video.

        The first frame's shape is unpacked as ``(h, w, c)`` and the native
        width/height are passed with the configured size to
        ``sensible_buckets``.
        """
        h, w, _c = vr[0].shape
        width, height = sensible_buckets(self.width, self.height, w, h)
        resize = T.transforms.Resize((height, width), antialias=True)
        return resize

    @staticmethod
    def _sample_frame_indices(num_frames, native_fps, target_fps, n_sample_frames):
        """Pick a random, evenly strided window of frame indices.

        The stride is ``round(native_fps / target_fps)`` clamped to
        ``[1, num_frames]``. If the strided video holds fewer than
        ``n_sample_frames`` frames, the window shrinks to fit, so fewer
        indices may be returned.

        Raises:
            ValueError: if ``num_frames`` is not positive (the stride would
                otherwise be 0 and cause a division by zero).
        """
        if num_frames <= 0:
            raise ValueError("cannot sample frames from an empty video")
        # Keep every n-th frame so the clip plays at roughly ``target_fps``.
        every_nth_frame = max(1, round(native_fps / target_fps))
        every_nth_frame = min(num_frames, every_nth_frame)
        effective_length = num_frames // every_nth_frame
        n_sample_frames = min(n_sample_frames, effective_length)
        start = random.randint(0, effective_length - n_sample_frames)
        return every_nth_frame * np.arange(start, start + n_sample_frames)

    def get_frame_batch(self, vr, resize=None):
        """Sample a strided clip from *vr* and return ``(frames, vr)``.

        Frames are rearranged from ``f h w c`` to ``f c h w`` and passed
        through *resize* when one is given.
        """
        native_fps = vr.get_avg_fps()
        idxs = self._sample_frame_indices(
            len(vr), native_fps, self.fps, self.n_sample_frames
        )
        video = vr.get_batch(idxs)
        video = rearrange(video, "f h w c -> f c h w")
        if resize is not None:
            video = resize(video)
        return video, vr

    def process_video_wrapper(self, vid_path):
        """Load *vid_path* through ``process_video`` using this dataset's settings."""
        video, vr = process_video(
            vid_path,
            self.use_bucketing,
            self.width,
            self.height,
            self.get_frame_buckets,
            self.get_frame_batch,
        )
        return video, vr

    def get_prompt_ids(self, prompt):
        """Tokenize *prompt*, padded/truncated to the tokenizer's max length."""
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    @staticmethod
    def __getname__():
        """Return the short identifier of this dataset type."""
        return 'folder'

    def __len__(self):
        """Return the number of discovered video files."""
        return len(self.video_files)

    def __getitem__(self, index):
        """Return pixels, prompt ids, prompt text and dataset name for one clip."""
        video, _ = self.process_video_wrapper(self.video_files[index])
        prompt = self.fallback_prompt
        prompt_ids = self.get_prompt_ids(prompt)
        # NOTE(review): ``video`` appears to be the ``(frames, vr)`` tuple from
        # get_frame_batch as returned by process_video, hence ``video[0]``;
        # the /127.5 - 1.0 scaling assumes 8-bit pixels (maps 0..255 to -1..1).
        # Confirm against utils.dataset_utils.process_video.
        return {
            "pixel_values": (video[0] / 127.5 - 1.0),
            "prompt_ids": prompt_ids[0],
            "text_prompt": prompt,
            'dataset': self.__getname__(),
        }