import os
import decord
import numpy as np
import random
import json
import torchvision
import torchvision.transforms as T
import torch
from glob import glob
from PIL import Image
from itertools import islice
from pathlib import Path
from .bucketing import sensible_buckets
from torch.utils.data import Dataset
from einops import rearrange, repeat

decord.bridge.set_bridge('torch')
def get_prompt_ids(prompt, tokenizer):
prompt_ids = tokenizer(
prompt,
truncation=True,
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
).input_ids
return prompt_ids
def read_caption_file(caption_file):
with open(caption_file, 'r', encoding="utf8") as t:
return t.read()
def get_text_prompt(
    text_prompt: str = '',
    fallback_prompt: str = '',
    file_path: str = '',
    ext_types=('.mp4',),
    use_caption=False
):
    try:
        if use_caption:
            if len(text_prompt) > 1: return text_prompt
            caption_file = ''
            # Use caption on a per-video basis (one caption file per video).
            for ext in ext_types:
                maybe_file = file_path.replace(ext, '.txt')
                if maybe_file.endswith(tuple(ext_types)): continue
                if os.path.exists(maybe_file):
                    caption_file = maybe_file
                    break

            if os.path.exists(caption_file):
                return read_caption_file(caption_file)

            # Return the fallback prompt if no caption file was found.
            return fallback_prompt

        return text_prompt
    except Exception:
        print(f"Couldn't read prompt caption for {file_path}. Using fallback.")
        return fallback_prompt
def get_video_frames(vr, start_idx, sample_rate=1, max_frames=24):
max_range = len(vr)
frame_number = sorted((0, start_idx, max_range))[1]
frame_range = range(frame_number, max_range, sample_rate)
frame_range_indices = list(frame_range)[:max_frames]
return frame_range_indices
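# process_video dispatches between the two loading modes used by the datasets
# below: with bucketing, the video is decoded at its native resolution and
# resized afterwards by the transform returned from get_frame_buckets; without
# it, decord decodes directly at the requested width/height.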
def process_video(vid_path, use_bucketing, w, h, get_frame_buckets, get_frame_batch):
if use_bucketing:
vr = decord.VideoReader(vid_path)
resize = get_frame_buckets(vr)
video = get_frame_batch(vr, resize=resize)
else:
vr = decord.VideoReader(vid_path, width=w, height=h)
video = get_frame_batch(vr)
return video, vr
# https://github.com/ExponentialML/Video-BLIP2-Preprocessor
class VideoJsonDataset(Dataset):
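    """
    Loads videos and prompts from a JSON file in the format produced by the
    Video-BLIP2 preprocessor linked above. Each top-level entry holds a video
    path (under `vid_data_key`) plus nested clips with a `frame_index`, a
    `prompt`, and an optional pre-cut `clip_path`; `build_json` flattens the
    nested entries into `self.train_data`.
    """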
def __init__(
self,
tokenizer = None,
width: int = 256,
height: int = 256,
n_sample_frames: int = 4,
sample_start_idx: int = 1,
frame_step: int = 1,
json_path: str ="",
json_data = None,
vid_data_key: str = "video_path",
preprocessed: bool = False,
use_bucketing: bool = False,
**kwargs
):
self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
self.use_bucketing = use_bucketing
self.tokenizer = tokenizer
self.preprocessed = preprocessed
self.vid_data_key = vid_data_key
self.train_data = self.load_from_json(json_path, json_data)
self.width = width
self.height = height
self.n_sample_frames = n_sample_frames
self.sample_start_idx = sample_start_idx
self.frame_step = frame_step
def build_json(self, json_data):
extended_data = []
for data in json_data['data']:
for nested_data in data['data']:
self.build_json_dict(
data,
nested_data,
extended_data
)
json_data = extended_data
return json_data
def build_json_dict(self, data, nested_data, extended_data):
clip_path = nested_data['clip_path'] if 'clip_path' in nested_data else None
extended_data.append({
self.vid_data_key: data[self.vid_data_key],
'frame_index': nested_data['frame_index'],
'prompt': nested_data['prompt'],
'clip_path': clip_path
})
    def load_from_json(self, path, json_data):
        try:
            with open(path) as jpath:
                print(f"Loading JSON from {path}")
                json_data = json.load(jpath)

            return self.build_json(json_data)
        except Exception:
            print("Nonexistent or unreadable JSON path. Skipping.")
            return []
def validate_json(self, base_path, path):
return os.path.exists(f"{base_path}/{path}")
def get_frame_range(self, vr):
return get_video_frames(
vr,
self.sample_start_idx,
self.frame_step,
self.n_sample_frames
)
    def get_vid_idx(self, vr, vid_data=None):
        if vid_data is not None:
            idx = vid_data['frame_index']
        else:
            idx = self.sample_start_idx

        return idx
def get_frame_buckets(self, vr):
_, h, w = vr[0].shape
width, height = sensible_buckets(self.width, self.height, h, w)
resize = T.transforms.Resize((height, width), antialias=True)
return resize
def get_frame_batch(self, vr, resize=None):
frame_range = self.get_frame_range(vr)
frames = vr.get_batch(frame_range)
video = rearrange(frames, "f h w c -> f c h w")
if resize is not None: video = resize(video)
return video
def process_video_wrapper(self, vid_path):
video, vr = process_video(
vid_path,
self.use_bucketing,
self.width,
self.height,
self.get_frame_buckets,
self.get_frame_batch
)
return video, vr
    def train_data_batch(self, index):
        # If we are training on individual clips.
        if 'clip_path' in self.train_data[index] and \
            self.train_data[index]['clip_path'] is not None:

            vid_data = self.train_data[index]
            clip_path = vid_data['clip_path']

            # Get video prompt
            prompt = vid_data['prompt']

            video, _ = self.process_video_wrapper(clip_path)
            prompt_ids = get_prompt_ids(prompt, self.tokenizer)

            return video, prompt, prompt_ids

        # Assign train data
        train_data = self.train_data[index]

        # Get the frame of the current index.
        self.sample_start_idx = train_data['frame_index']

        video, vr = self.process_video_wrapper(train_data[self.vid_data_key])

        # Get video prompt
        prompt = train_data['prompt']
        vr.seek(0)

        prompt_ids = get_prompt_ids(prompt, self.tokenizer)

        return video, prompt, prompt_ids
@staticmethod
def __getname__(): return 'json'
def __len__(self):
if self.train_data is not None:
return len(self.train_data)
else:
return 0
def __getitem__(self, index):
# Initialize variables
video = None
prompt = None
prompt_ids = None
# Use default JSON training
if self.train_data is not None:
video, prompt, prompt_ids = self.train_data_batch(index)
example = {
"pixel_values": (video / 127.5 - 1.0),
"prompt_ids": prompt_ids[0],
"text_prompt": prompt,
'dataset': self.__getname__()
}
return example
class SingleVideoDataset(Dataset):
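    """
    Trains on a single video by splitting its frame indices into fixed-size
    chunks (see `create_video_chunks`); each item returns one chunk of
    `n_sample_frames` frames paired with `single_video_prompt`.
    """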
def __init__(
self,
tokenizer = None,
width: int = 256,
height: int = 256,
n_sample_frames: int = 4,
frame_step: int = 1,
single_video_path: str = "",
single_video_prompt: str = "",
use_caption: bool = False,
use_bucketing: bool = False,
**kwargs
):
self.tokenizer = tokenizer
self.use_bucketing = use_bucketing
self.frames = []
self.index = 1
self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
self.n_sample_frames = n_sample_frames
self.frame_step = frame_step
self.single_video_path = single_video_path
self.single_video_prompt = single_video_prompt
self.width = width
self.height = height
    def create_video_chunks(self):
        # Split the video's frame indices into chunks of n_sample_frames,
        # e.g. [(1, 2, 3), (4, 5, 6), ...]
        vr = decord.VideoReader(self.single_video_path)
        vr_range = range(1, len(vr), self.frame_step)

        self.frames = list(self.chunk(vr_range, self.n_sample_frames))

        # Drop any chunk containing an out-of-range frame index. Filtering into
        # a new list avoids deleting from the list while iterating over it.
        valid_chunks = []
        for i, inner_frame_nums in enumerate(self.frames):
            if any(frame_num >= len(vr) for frame_num in inner_frame_nums):
                print(f"Removing out of range index list at position: {i}...")
                continue
            valid_chunks.append(inner_frame_nums)

        self.frames = valid_chunks
        return self.frames
def chunk(self, it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
    def get_frame_batch(self, vr, resize=None):
        frames = vr.get_batch(self.frames[self.index])
        video = rearrange(frames, "f h w c -> f c h w")

        if resize is not None: video = resize(video)
        return video
def get_frame_buckets(self, vr):
_, h, w = vr[0].shape
width, height = sensible_buckets(self.width, self.height, h, w)
resize = T.transforms.Resize((height, width), antialias=True)
return resize
def process_video_wrapper(self, vid_path):
video, vr = process_video(
vid_path,
self.use_bucketing,
self.width,
self.height,
self.get_frame_buckets,
self.get_frame_batch
)
return video, vr
def single_video_batch(self, index):
train_data = self.single_video_path
self.index = index
if train_data.endswith(self.vid_types):
video, _ = self.process_video_wrapper(train_data)
prompt = self.single_video_prompt
prompt_ids = get_prompt_ids(prompt, self.tokenizer)
return video, prompt, prompt_ids
else:
raise ValueError(f"Single video is not a video type. Types: {self.vid_types}")
@staticmethod
def __getname__(): return 'single_video'
def __len__(self):
return len(self.create_video_chunks())
def __getitem__(self, index):
video, prompt, prompt_ids = self.single_video_batch(index)
example = {
"pixel_values": (video / 127.5 - 1.0),
"prompt_ids": prompt_ids[0],
"text_prompt": prompt,
'dataset': self.__getname__()
}
return example
class ImageDataset(Dataset):
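    """
    Treats each image in `image_dir` as a single-frame video (the frame axis is
    added with `repeat`). The prompt comes from `single_img_prompt`, a sidecar
    .txt caption next to the image, or `fallback_prompt`, in that order.
    """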
def __init__(
self,
tokenizer = None,
width: int = 256,
height: int = 256,
base_width: int = 256,
base_height: int = 256,
use_caption: bool = False,
image_dir: str = '',
single_img_prompt: str = '',
use_bucketing: bool = False,
fallback_prompt: str = '',
**kwargs
):
self.tokenizer = tokenizer
        self.img_types = (".png", ".jpg", ".jpeg", ".bmp")
self.use_bucketing = use_bucketing
self.image_dir = self.get_images_list(image_dir)
self.fallback_prompt = fallback_prompt
self.use_caption = use_caption
self.single_img_prompt = single_img_prompt
self.width = width
self.height = height
def get_images_list(self, image_dir):
if os.path.exists(image_dir):
imgs = [x for x in os.listdir(image_dir) if x.endswith(self.img_types)]
full_img_dir = []
for img in imgs:
full_img_dir.append(f"{image_dir}/{img}")
return sorted(full_img_dir)
return ['']
def image_batch(self, index):
train_data = self.image_dir[index]
img = train_data
        try:
            img = torchvision.io.read_image(img, mode=torchvision.io.ImageReadMode.RGB)
        except Exception:
            # Fall back to PIL for formats torchvision.io cannot decode.
            img = T.transforms.PILToTensor()(Image.open(img).convert("RGB"))
width = self.width
height = self.height
if self.use_bucketing:
_, h, w = img.shape
width, height = sensible_buckets(width, height, w, h)
resize = T.transforms.Resize((height, width), antialias=True)
img = resize(img)
img = repeat(img, 'c h w -> f c h w', f=1)
prompt = get_text_prompt(
file_path=train_data,
text_prompt=self.single_img_prompt,
fallback_prompt=self.fallback_prompt,
ext_types=self.img_types,
use_caption=True
)
prompt_ids = get_prompt_ids(prompt, self.tokenizer)
return img, prompt, prompt_ids
@staticmethod
def __getname__(): return 'image'
def __len__(self):
# Image directory
if os.path.exists(self.image_dir[0]):
return len(self.image_dir)
else:
return 0
def __getitem__(self, index):
img, prompt, prompt_ids = self.image_batch(index)
example = {
"pixel_values": (img / 127.5 - 1.0),
"prompt_ids": prompt_ids[0],
"text_prompt": prompt,
'dataset': self.__getname__()
}
return example
class VideoFolderDataset(Dataset):
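    """
    Loads every *.mp4 under `path`. Frames are sampled to approximate the
    target `fps` from each clip's native frame rate, and the caption is read
    from a sidecar .txt file with the same stem, falling back to
    `fallback_prompt` when none exists.
    """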
def __init__(
self,
tokenizer=None,
width: int = 256,
height: int = 256,
n_sample_frames: int = 16,
fps: int = 8,
path: str = "./data",
fallback_prompt: str = "",
use_bucketing: bool = False,
**kwargs
):
self.tokenizer = tokenizer
self.use_bucketing = use_bucketing
self.fallback_prompt = fallback_prompt
self.video_files = glob(f"{path}/*.mp4")
self.width = width
self.height = height
self.n_sample_frames = n_sample_frames
self.fps = fps
def get_frame_buckets(self, vr):
_, h, w = vr[0].shape
width, height = sensible_buckets(self.width, self.height, h, w)
resize = T.transforms.Resize((height, width), antialias=True)
return resize
    def get_frame_batch(self, vr, resize=None):
        n_sample_frames = self.n_sample_frames
        native_fps = vr.get_avg_fps()

        every_nth_frame = max(1, round(native_fps / self.fps))
        every_nth_frame = min(len(vr), every_nth_frame)

        effective_length = len(vr) // every_nth_frame
        if effective_length < n_sample_frames:
            n_sample_frames = effective_length

        effective_idx = random.randint(0, (effective_length - n_sample_frames))
        idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)

        video = vr.get_batch(idxs)
        video = rearrange(video, "f h w c -> f c h w")

        if resize is not None: video = resize(video)
        return video
def process_video_wrapper(self, vid_path):
video, vr = process_video(
vid_path,
self.use_bucketing,
self.width,
self.height,
self.get_frame_buckets,
self.get_frame_batch
)
return video, vr
def get_prompt_ids(self, prompt):
return self.tokenizer(
prompt,
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids
@staticmethod
def __getname__(): return 'folder'
def __len__(self):
return len(self.video_files)
    def __getitem__(self, index):
        video, _ = self.process_video_wrapper(self.video_files[index])

        caption_file = self.video_files[index].replace(".mp4", ".txt")
        if os.path.exists(caption_file):
            with open(caption_file, "r") as f:
                prompt = f.read()
        else:
            prompt = self.fallback_prompt

        prompt_ids = self.get_prompt_ids(prompt)

        return {
            "pixel_values": (video / 127.5 - 1.0),
            "prompt_ids": prompt_ids[0],
            "text_prompt": prompt,
            'dataset': self.__getname__()
        }
class CachedDataset(Dataset):
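    """
    Serves pre-computed .pt tensors (cached latents) from `cache_dir`,
    sorted by filename. Items are loaded directly onto 'cuda:0'.
    """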
    def __init__(self, cache_dir: str = ''):
self.cache_dir = cache_dir
self.cached_data_list = self.get_files_list()
def get_files_list(self):
tensors_list = [f"{self.cache_dir}/{x}" for x in os.listdir(self.cache_dir) if x.endswith('.pt')]
return sorted(tensors_list)
def __len__(self):
return len(self.cached_data_list)
def __getitem__(self, index):
cached_latent = torch.load(self.cached_data_list[index], map_location='cuda:0')
return cached_latent
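

# Illustrative usage sketch: wrapping one of the datasets above in a DataLoader.
# The CLIPTokenizer import, the "openai/clip-vit-large-patch14" checkpoint, and
# the "./data" folder are assumptions made for this example; any tokenizer that
# exposes `model_max_length` and returns `.input_ids`, as expected by
# `get_prompt_ids` above, will work.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer  # assumed dependency for the example

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    dataset = VideoFolderDataset(
        tokenizer=tokenizer,
        width=256,
        height=256,
        n_sample_frames=16,
        fps=8,
        path="./data",                  # folder of *.mp4 clips (+ optional *.txt captions)
        fallback_prompt="a short video clip",
    )

    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    for batch in loader:
        # pixel_values: (batch, frames, channels, height, width), scaled to [-1, 1]
        print(batch["pixel_values"].shape, batch["text_prompt"])
        break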