from typing import Callable, List, Optional, Union |
from torch.utils.data import Dataset |
import decord |
decord.bridge.set_bridge('torch') |
from einops import rearrange |
import random |
import os |
import json |
from PIL import Image, ImageFilter |
import numpy as np |
import cv2 |
from scipy import ndimage |
import tempfile |
import ffmpeg |
from transformers import CLIPTokenizer |
class FramesDataset(Dataset):
    """Dataset of short video clips cut from directories of numbered PNG frames.

    Typical use is two-phase:

    1. ``prepare()`` scans the frame directories named by the prompt map,
       randomly picks key frames, and writes ``<index>.mp4`` /
       ``<index>.json`` pairs into ``samples_dir``.
    2. ``load()`` reads those JSON files back and tokenizes their prompts,
       after which the dataset can be indexed.

    The prompt map is a JSON object of the form
    ``{frame_dir: {start_frame_number: prompt, ...}, ...}``.
    """

    def __init__(
        self,
        samples_dir: str,
        prompt_map_path: Union[str, list[str]],
        width: int = 512,
        height: int = 512,
        video_length: int = 16,
        sample_start_index: int = 0,
        sample_count: int = 1,
        sample_frame_rate: int = 8,
        variance_threshold: int = 50,
        tokenizer: Optional["CLIPTokenizer"] = None,
    ):
        """Store the configuration and read the prompt map(s).

        Args:
            samples_dir: Directory where samples are written and read.
            prompt_map_path: Path to a prompt-map JSON file, or a list of
                such paths. With a list, the maps are merged in order and
                later files override earlier ones for the same directory.
            width: Width frames are resized to.
            height: Height frames are resized to.
            video_length: Number of frames per sample.
            sample_start_index: Index used to name the first written sample.
            sample_count: Number of samples ``prepare()`` writes.
            sample_frame_rate: Frame rate of the written videos.
            variance_threshold: Upper bound used by ``check()``.
            tokenizer: Tokenizer used by ``tokenize()``; required by
                ``load()``.
        """
        super().__init__()
        print("FramesDataset", "init", width, height, video_length, sample_count)
        self.width = width
        self.height = height
        self.video_length = video_length
        self.sample_count = sample_count
        self.tokenizer = tokenizer
        self.samples_dir = samples_dir
        self.sample_start_index = sample_start_index
        self.sample_frame_rate = sample_frame_rate
        self.variance_threshold = variance_threshold
        self.samples = []

        # The annotation allows a list of paths, so accept one or many.
        if isinstance(prompt_map_path, (str, os.PathLike)):
            map_paths = [prompt_map_path]
        else:
            map_paths = list(prompt_map_path)
        self.prompt_map = {}
        for map_path in map_paths:
            with open(map_path, 'r', encoding='utf-8') as f:
                self.prompt_map.update(json.load(f))

        self.frames_path = [str(k) for k in self.prompt_map]
        print("FramesDataset", "init", "frames_path", len(self.frames_path))

    @staticmethod
    def _parse_index(stem, canonical=False):
        """Return ``int(stem)``, or None when ``stem`` is not an integer.

        With ``canonical=True`` the stem must also round-trip through
        ``str(int(stem))`` (no zero padding, whitespace or underscores).
        """
        try:
            value = int(stem)
        except ValueError:
            return None
        if canonical and str(value) != stem:
            return None
        return value

    def load(self):
        """Read every ``<integer>.json`` in ``samples_dir`` into ``self.samples``.

        Files are processed in ascending numeric order. Each sample dict
        gains a ``prompt_ids`` entry holding the tokenized prompt. Entries
        that are not ``.json`` files with an integer stem are ignored.
        """
        print("FramesDataset", "load", "samples_dir", self.samples_dir)
        indexed = []
        for filename in os.listdir(self.samples_dir):
            stem, ext = os.path.splitext(filename)
            number = self._parse_index(stem)
            if ext == '.json' and number is not None:
                indexed.append((number, filename))

        self.samples = []
        for _, filename in sorted(indexed):
            with open(f"{self.samples_dir}/{filename}", 'r', encoding='utf-8') as f:
                sample = json.load(f)
            sample['prompt_ids'] = self.tokenize(sample['prompt'])
            self.samples.append(sample)
        print("FramesDataset", "load", "samples", len(self.samples))

    def tokenize(self, prompt):
        """Return the token ids of ``prompt``, padded/truncated to the model max length.

        Raises:
            ValueError: If the dataset was built without a tokenizer.
        """
        if self.tokenizer is None:
            raise ValueError("FramesDataset.tokenize requires a tokenizer")
        input_ids = self.tokenizer(
            prompt,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids[0]
        return input_ids

    def prepare(self):
        """Collect key-frame candidates from all frame dirs and write samples."""
        print("FramesDataset", "prepare")
        candidates = []
        for dir_path in self.frames_path:
            candidates.extend(self.load_key_frames(dir_path))
        print("FramesDataset", "prepare", "candidates", len(candidates))
        self.pick(self.sample_count, candidates)

    def pick(self, count, candidates):
        """Write ``count`` samples built from randomly chosen key frames.

        For each chosen key frame, ``video_length`` consecutive frames are
        read, resized and passed to ``check()``. Rejected clips are skipped
        and another key frame is drawn, so this keeps looping until
        ``count`` clips have been accepted. Accepted clips are written as
        ``<samples_dir>/<index>.mp4`` with an ``<index>.json`` metadata file,
        with indices starting at ``sample_start_index``.

        Args:
            count: Number of samples to write; nothing is written if <= 0.
            candidates: Key-frame file paths of the form ``<dir>/<n>.png``.

        Raises:
            IndexError: If ``candidates`` is empty while ``count`` > 0.
        """
        print("FramesDataset", "pick", count, len(candidates))
        if count <= 0:
            print("FramesDataset", "pick", "done")
            return
        if not candidates:
            raise IndexError("FramesDataset.pick: no candidate key frames")

        os.makedirs(self.samples_dir, exist_ok=True)
        sample_index = self.sample_start_index
        end_index = self.sample_start_index + count
        while sample_index < end_index:
            key_frame = random.choice(candidates)
            print("FramesDataset", "pick", "key_frame", key_frame)
            dir_name = os.path.dirname(key_frame)
            file_name = os.path.basename(key_frame)
            frame_number = int(file_name.split(".")[0])

            frames = []
            for i in range(frame_number, frame_number + self.video_length):
                # Context manager closes the file handle once pixels are copied.
                with Image.open(f"{dir_name}/{i}.png") as frame:
                    resized = frame.resize((self.width, self.height))
                    frames.append(np.array(resized))
            sample = np.array(frames)
            print("FramesDataset", "pick", "reading sample", sample.shape)

            if not self.check(sample):
                print("FramesDataset", "pick", "skip")
                continue
            print("FramesDataset", "pick", "checked")

            prompt = self.get_prompt(key_frame)
            sample_file = f"{self.samples_dir}/{sample_index}.mp4"
            self.write_video(sample, sample_file, self.sample_frame_rate)
            print("FramesDataset", "pick", "sample_file", sample_file)

            meta_file = f"{self.samples_dir}/{sample_index}.json"
            with open(meta_file, 'w', encoding='utf-8') as f:
                f.write(json.dumps({
                    'key_frame': key_frame,
                    'video_file': sample_file,
                    'prompt': prompt,
                }))
            print("FramesDataset", "pick", "meta_file", meta_file)
            sample_index += 1
        print("FramesDataset", "pick", "done")

    def write_video(self, frames, video_file, video_fps):
        """Encode ``frames`` into ``video_file`` (libx264) at ``video_fps``.

        Frames are dumped as ``0.png, 1.png, ...`` into a temporary
        directory and handed to ffmpeg as an image sequence.
        """
        with tempfile.TemporaryDirectory() as frames_dir:
            for index, frame in enumerate(frames):
                Image.fromarray(frame).save(f"{frames_dir}/{index}.png")
            # The image-sequence input must carry the target frame rate.
            # Otherwise ffmpeg assumes 25 fps and the fps filter below drops
            # frames to reach video_fps, leaving fewer frames than supplied.
            (
                ffmpeg
                .input(f"{frames_dir}/%d.png", framerate=video_fps)
                .output(video_file, vcodec='libx264', vf=f"fps={video_fps}")
                .overwrite_output()
                .run()
            )

    def get_prompt(self, key_frame):
        """Return the prompt active at ``key_frame``, or ``""`` if none applies.

        The frame's directory selects a ``{start_frame: prompt}`` map. The
        prompt with the largest start frame not exceeding the key frame's
        number is returned, independent of key order in the JSON file.
        """
        print("FramesDataset", "get_prompt", key_frame)
        dir_name = os.path.dirname(key_frame)
        file_name = os.path.basename(key_frame)
        number = int(file_name.split(".")[0])

        prompts = self.prompt_map.get(dir_name)
        if prompts is None:
            # Tolerate cosmetic path differences such as a trailing slash.
            wanted = os.path.normpath(dir_name)
            for map_dir, map_prompts in self.prompt_map.items():
                if os.path.normpath(str(map_dir)) == wanted:
                    prompts = map_prompts
                    break

        best_key = None
        best_start = None
        for k in (prompts or {}):
            start = int(k)
            if start <= number and (best_start is None or start > best_start):
                best_key, best_start = k, start

        if best_key is None:
            print("FramesDataset", "get_prompt", "not found")
            return ""
        print("FramesDataset", "get_prompt", best_key, prompts[best_key])
        return prompts[best_key]

    def check(self, sample):
        """Return True when frame-to-frame change is steady across the clip.

        For each pair of consecutive frames, the sum of absolute differences
        of the blurred frames is computed. The standard deviation of those
        sums, as a percentage of the first sum, must be below
        ``variance_threshold``.

        A clip with fewer than two frames has nothing to compare and is
        accepted. A clip whose first two frames are identical after blurring
        is rejected, because the ratio is undefined.
        """
        if len(sample) < 2:
            return True
        # Widen from uint8 so the subtraction cannot wrap around.
        blurred = [self.blur(frame).astype(np.int32) for frame in sample]
        diffs = [
            float(np.abs(after - before).sum(dtype=np.int64))
            for before, after in zip(blurred, blurred[1:])
        ]
        first_diff = diffs[0]
        if first_diff == 0:
            return False
        variance = np.std(diffs) / first_diff * 100
        return bool(variance < self.variance_threshold)

    def blur(self, frame):
        """Return ``frame`` after a Gaussian blur of radius 5, as an array."""
        image = Image.fromarray(frame)
        image = image.filter(ImageFilter.GaussianBlur(radius=5))
        return np.array(image)

    def load_key_frames(self, dir_path):
        """Return paths of frames in ``dir_path`` that can start a full clip.

        Only files named ``<integer>.png`` are considered; anything else in
        the directory is ignored. A frame qualifies when the next
        ``video_length - 1`` frame numbers are present too.

        Raises:
            NotADirectoryError: If ``dir_path`` is not an existing directory.
        """
        print("FramesDataset", "load_key_frames", dir_path)
        if not os.path.isdir(dir_path):
            raise NotADirectoryError(f"Dir not exist: {dir_path}")

        frame_numbers = set()
        for file_name in os.listdir(dir_path):
            stem, ext = os.path.splitext(file_name)
            # Canonical names only: pick() rebuilds paths as f"{n}.png".
            number = self._parse_index(stem, canonical=True)
            if ext == '.png' and number is not None:
                frame_numbers.add(number)
        print("FramesDataset", "load_key_frames", "files", len(frame_numbers))

        candidates = [
            f"{dir_path}/{number}.png"
            for number in sorted(frame_numbers)
            if all(
                number + offset in frame_numbers
                for offset in range(self.video_length)
            )
        ]
        print("FramesDataset", "load_key_frames", "candidates", len(candidates))
        return candidates

    def __len__(self):
        """Return the number of samples loaded by ``load()``."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return sample ``index`` with its decoded video under ``pixel_values``.

        The first ``video_length`` frames of the sample's video (fewer if
        the video is shorter) are decoded, rearranged from ``f h w c`` to
        ``f c h w`` and rescaled with ``x / 127.5 - 1.0``.
        """
        meta = self.samples[index]
        vr = decord.VideoReader(meta['video_file'])
        frame_indices = list(range(min(len(vr), self.video_length)))
        video = vr.get_batch(frame_indices)
        video = rearrange(video, "f h w c -> f c h w")
        # Return a copy so decoded tensors are not retained in self.samples.
        item = dict(meta)
        item['pixel_values'] = (video / 127.5 - 1.0)
        return item
if __name__ == "__main__":
    # Manual smoke run: build a dataset against the test fixtures and
    # generate its samples on disk.
    tokenizer = CLIPTokenizer.from_pretrained(
        'runwayml/stable-diffusion-v1-5',
        subfolder="tokenizer",
    )
    settings = dict(
        samples_dir="test/FramesDataset/samples_dir",
        prompt_map_path='test/FramesDataset/prompt_map.json',
        width=512,
        height=512,
        video_length=16,
        sample_count=1,
        tokenizer=tokenizer,
        variance_threshold=40,
    )
    dataset = FramesDataset(**settings)
    dataset.prepare()