|
from typing import Callable, List, Optional, Union |
|
from torch.utils.data import Dataset |
|
|
|
import decord |
|
decord.bridge.set_bridge('torch') |
|
|
|
from einops import rearrange |
|
|
|
import random |
|
import os |
|
import json |
|
from PIL import Image, ImageFilter |
|
import numpy as np |
|
import cv2 |
|
from scipy import ndimage |
|
import tempfile |
|
import ffmpeg |
|
|
|
from transformers import CLIPTokenizer |
|
|
|
|
|
class FramesDataset(Dataset):
    """Build short video clips from numbered PNG frame folders and serve them for training.

    Workflow:

    * ``prepare()`` scans every frame directory listed in the prompt map, picks
      random start frames, rejects clips whose inter-frame motion is too uneven
      (see ``check``), and writes each accepted clip into ``samples_dir`` as
      ``{index}.mp4`` plus a ``{index}.json`` metadata file.
    * ``load()`` reads those ``{index}.json`` files back and tokenizes their prompts.
    * ``__getitem__`` decodes a clip and returns its metadata plus ``pixel_values``.

    The prompt map JSON has the shape ``{frames_dir: {start_frame: prompt}}``; a
    frame uses the prompt of the greatest ``start_frame`` that is not after it.
    Frames inside a directory must be named ``{n}.png`` (no zero padding),
    because that is how clip frames are read back.
    """

    def __init__(
        self,
        samples_dir: str,
        prompt_map_path: Union[str, List[str]],
        width: int = 512,
        height: int = 512,
        video_length: int = 16,
        sample_start_index: int = 0,
        sample_count: int = 1,
        sample_frame_rate: int = 8,
        variance_threshold: int = 50,
        tokenizer: Optional["CLIPTokenizer"] = None,
    ):
        """Configure the dataset and read the prompt map(s).

        Args:
            samples_dir: Directory where generated ``{index}.mp4``/``{index}.json``
                samples are written by ``prepare`` and read by ``load``.
            prompt_map_path: Path to a prompt-map JSON file, or a list of paths whose
                maps are merged (later files override earlier ones per directory).
            width, height: Size every frame is resized to.
            video_length: Number of consecutive frames per clip.
            sample_start_index: Index used for the first generated sample file.
            sample_count: Number of samples ``prepare`` generates.
            sample_frame_rate: Frame rate of the written mp4 files.
            variance_threshold: Maximum allowed motion variation (percent), see ``check``.
            tokenizer: Tokenizer used by ``tokenize``; required by ``load``.
        """
        print("FramesDataset", "init", width, height, video_length, sample_count)

        self.width = width
        self.height = height
        self.video_length = video_length
        self.sample_count = sample_count
        self.tokenizer = tokenizer
        self.samples_dir = samples_dir
        self.sample_start_index = sample_start_index
        self.sample_frame_rate = sample_frame_rate
        self.variance_threshold = variance_threshold

        self.samples = []

        self.prompt_map = self._read_prompt_maps(prompt_map_path)
        self.frames_path = [str(key) for key in self.prompt_map]

        print("FramesDataset", "init", "frames_path", len(self.frames_path))

    @staticmethod
    def _read_prompt_maps(prompt_map_path):
        """Load one prompt-map JSON file, or merge several (later files win per key)."""
        if isinstance(prompt_map_path, (str, os.PathLike)):
            paths = [prompt_map_path]
        else:
            paths = list(prompt_map_path)
        merged = {}
        for path in paths:
            with open(path, 'r', encoding='utf-8') as f:
                merged.update(json.load(f))
        return merged

    @staticmethod
    def _frame_number(filename):
        """Return the integer before the first dot of *filename* (``"12.png"`` -> 12)."""
        return int(filename.split('.')[0])

    @staticmethod
    def _try_frame_number(filename):
        """Like ``_frame_number`` but return None when the stem is not an integer."""
        try:
            return int(filename.split('.')[0])
        except ValueError:
            return None

    def load(self):
        """Load every ``{index}.json`` sample in ``samples_dir`` and tokenize its prompt.

        Samples are ordered by their integer index; other files (videos, unrelated
        JSON) are ignored instead of crashing the sort.
        """
        print("FramesDataset", "load", "samples_dir", self.samples_dir)

        self.samples = []
        indexed = []
        for filename in os.listdir(self.samples_dir):
            if not filename.endswith('.json'):
                continue
            number = self._try_frame_number(filename)
            if number is not None:
                indexed.append((number, filename))
        indexed.sort(key=lambda item: item[0])

        for _, filename in indexed:
            with open(os.path.join(self.samples_dir, filename), 'r', encoding='utf-8') as f:
                sample = json.load(f)
            sample['prompt_ids'] = self.tokenize(sample['prompt'])
            self.samples.append(sample)

        print("FramesDataset", "load", "samples", len(self.samples))

    def tokenize(self, prompt):
        """Return the token ids for a single *prompt*, padded/truncated to the model max length.

        Raises:
            ValueError: if the dataset was built without a tokenizer.
        """
        if self.tokenizer is None:
            raise ValueError(
                "FramesDataset.tokenize requires a tokenizer; pass tokenizer= to the constructor"
            )
        return self.tokenizer(
            prompt,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids[0]

    def prepare(self):
        """Collect key-frame candidates from every mapped directory and generate samples."""
        print("FramesDataset", "prepare")

        candidates = []
        for dir_path in self.frames_path:
            candidates.extend(self.load_key_frames(dir_path))

        print("FramesDataset", "prepare", "candidates", len(candidates))

        self.pick(self.sample_count, candidates)

    def pick(self, count, candidates, max_attempts: Optional[int] = None):
        """Write *count* accepted clips, numbered from ``sample_start_index``.

        Key frames are drawn at random (with replacement) from *candidates*; clips
        rejected by ``check`` are skipped and another key frame is drawn.

        Args:
            count: Number of samples to write; ``<= 0`` writes nothing.
            candidates: Key-frame paths as returned by ``load_key_frames``.
            max_attempts: Optional cap on draws; None retries until *count* clips pass.

        Raises:
            ValueError: if samples are requested but *candidates* is empty.
            RuntimeError: if *max_attempts* draws pass without writing *count* clips.
        """
        print("FramesDataset", "pick", count, len(candidates))

        if count <= 0:
            return
        if not candidates:
            raise ValueError("FramesDataset.pick: no candidate key frames to sample from")

        sample_index = self.sample_start_index
        end_index = self.sample_start_index + count
        attempts = 0
        while sample_index < end_index:
            if max_attempts is not None and attempts >= max_attempts:
                raise RuntimeError(
                    f"FramesDataset.pick: only {sample_index - self.sample_start_index} of "
                    f"{count} samples passed check after {attempts} attempts"
                )
            attempts += 1

            key_frame = random.choice(candidates)
            print("FramesDataset", "pick", "key_frame", key_frame)

            sample = self._read_clip(key_frame)
            print("FramesDataset", "pick", "reading sample", sample.shape)

            if not self.check(sample):
                print("FramesDataset", "pick", "skip")
                continue

            print("FramesDataset", "pick", "checked")

            self._write_sample(sample_index, key_frame, sample)
            sample_index += 1

        print("FramesDataset", "pick", "done")

    def _read_clip(self, key_frame):
        """Read ``video_length`` consecutive frames starting at *key_frame* as one array."""
        dir_name = os.path.dirname(key_frame)
        start = self._frame_number(os.path.basename(key_frame))
        frames = []
        for number in range(start, start + self.video_length):
            # `with` closes the file handle; RGB conversion keeps every frame the same
            # shape and avoids palette images, which GaussianBlur cannot filter.
            with Image.open(f"{dir_name}/{number}.png") as image:
                frames.append(np.array(image.convert("RGB").resize((self.width, self.height))))
        return np.array(frames)

    def _write_sample(self, sample_index, key_frame, sample):
        """Write the clip as ``{sample_index}.mp4`` plus its ``{sample_index}.json`` metadata."""
        prompt = self.get_prompt(key_frame)

        sample_file = f"{self.samples_dir}/{sample_index}.mp4"
        self.write_video(sample, sample_file, self.sample_frame_rate)
        print("FramesDataset", "pick", "sample_file", sample_file)

        meta_file = f"{self.samples_dir}/{sample_index}.json"
        with open(meta_file, 'w') as f:
            json.dump({
                'key_frame': key_frame,
                'video_file': sample_file,
                'prompt': prompt,
            }, f)
        print("FramesDataset", "pick", "meta_file", meta_file)

    def write_video(self, frames, video_file, video_fps):
        """Encode *frames* (H x W x C uint8 arrays) to *video_file* with libx264 at *video_fps*.

        The image sequence is read at *video_fps* itself. Previously it was read at
        ffmpeg's default 25 fps and resampled with an ``fps`` filter, which dropped
        frames so the video held fewer than ``len(frames)`` frames.
        """
        with tempfile.TemporaryDirectory() as frames_dir:
            for index, frame in enumerate(frames):
                Image.fromarray(frame).save(f"{frames_dir}/{index}.png")

            (ffmpeg
                .input(f"{frames_dir}/%d.png", framerate=video_fps)
                .output(video_file, vcodec='libx264')
                .overwrite_output()
                .run())

    def _lookup_prompt_map(self, dir_name):
        """Return the prompt map for *dir_name*, tolerating trailing slashes etc. in keys."""
        if dir_name in self.prompt_map:
            return self.prompt_map[dir_name]
        target = os.path.normpath(dir_name)
        for key, value in self.prompt_map.items():
            if os.path.normpath(key) == target:
                return value
        return None

    def get_prompt(self, key_frame):
        """Return the prompt whose start frame is the greatest one <= *key_frame*'s number.

        Returns "" when the directory has no map or no start frame precedes the frame.
        """
        print("FramesDataset", "get_prompt", key_frame)

        dir_name = os.path.dirname(key_frame)
        number = self._frame_number(os.path.basename(key_frame))

        prompt_map = self._lookup_prompt_map(dir_name)
        if prompt_map is not None:
            eligible = [k for k in prompt_map if int(k) <= number]
            if eligible:
                # Closest preceding start frame, independent of the JSON key order.
                k = max(eligible, key=int)
                print("FramesDataset", "get_prompt", k, prompt_map[k])
                return prompt_map[k]

        print("FramesDataset", "get_prompt", "not found")
        return ""

    def check(self, sample):
        """Return True when the clip's frame-to-frame motion is even enough to keep.

        Each frame is blurred once. For every consecutive pair, the summed absolute
        pixel difference is computed. The clip passes when the standard deviation of
        those sums, as a percentage of the first sum, is below ``variance_threshold``.
        Clips with fewer than two frames have nothing to compare and pass.
        """
        if len(sample) < 2:
            return True
        blurred = [self.blur(frame) for frame in sample]
        return self._diffs_are_stable(self._frame_diffs(blurred), self.variance_threshold)

    @staticmethod
    def _frame_diffs(frames):
        """Return the summed absolute difference of each consecutive frame pair.

        Frames are widened to int64 first: subtracting uint8 arrays wraps around.
        """
        widened = [np.asarray(frame).astype(np.int64) for frame in frames]
        return [int(np.abs(nxt - cur).sum()) for cur, nxt in zip(widened, widened[1:])]

    @staticmethod
    def _diffs_are_stable(diffs, threshold):
        """Return True when std(diffs) / diffs[0] * 100 < threshold.

        A zero first difference rejects the clip, as the division previously yielded nan/inf.
        """
        first_diff = diffs[0]
        if first_diff == 0:
            return False
        variation = np.std(diffs) / first_diff * 100
        return bool(variation < threshold)

    def blur(self, frame):
        """Return *frame* after a radius-5 Gaussian blur (presumably to damp pixel noise)."""
        image = Image.fromarray(frame)
        image = image.filter(ImageFilter.GaussianBlur(radius=5))
        return np.array(image)

    def load_key_frames(self, dir_path):
        """Return paths of frames in *dir_path* that can start a full clip.

        A frame ``{n}.png`` qualifies when ``{n+1}.png`` ... ``{n+video_length-1}.png``
        all exist as well, so every frame the clip reads is present. Files that are
        not named ``{integer}.png`` are ignored.

        Raises:
            FileNotFoundError: if *dir_path* is not a directory.
        """
        print("FramesDataset", "load_key_frames", dir_path)

        if not os.path.isdir(dir_path):
            raise FileNotFoundError(f"Dir not exist: {dir_path}")

        available = set()
        for file_name in os.listdir(dir_path):
            number = self._try_frame_number(file_name)
            # Only exact "{n}.png" names count, because clip frames are read back by that name.
            if number is not None and file_name == f"{number}.png":
                available.add(number)
        print("FramesDataset", "load_key_frames", "files", len(available))

        candidates = [
            f"{dir_path}/{number}.png"
            for number in sorted(available)
            if all(number + offset in available for offset in range(1, self.video_length))
        ]

        print("FramesDataset", "load_key_frames", "candidates", len(candidates))
        return candidates

    def __len__(self):
        """Return the number of samples loaded by ``load``."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return a copy of sample *index*'s metadata with decoded ``pixel_values``.

        ``pixel_values`` is laid out (frames, channels, height, width) and holds at
        most ``video_length`` frames. Values are mapped from [0, 255] to [-1, 1],
        assuming decord yields 8-bit frames (torch bridge set at module import).
        A copy is returned so decoded tensors are not cached inside ``self.samples``.
        """
        meta = self.samples[index]
        vr = decord.VideoReader(meta['video_file'])
        frame_indices = list(range(min(len(vr), self.video_length)))
        video = vr.get_batch(frame_indices)
        video = rearrange(video, "f h w c -> f c h w")
        return {**meta, 'pixel_values': video / 127.5 - 1.0}
|
|
|
if __name__ == "__main__":
    # Smoke run: generate one 16-frame 512x512 sample from the test fixtures.
    clip_tokenizer = CLIPTokenizer.from_pretrained(
        'runwayml/stable-diffusion-v1-5',
        subfolder="tokenizer",
    )

    dataset_options = dict(
        samples_dir="test/FramesDataset/samples_dir",
        prompt_map_path='test/FramesDataset/prompt_map.json',
        width=512,
        height=512,
        video_length=16,
        sample_count=1,
        tokenizer=clip_tokenizer,
        variance_threshold=40,
    )

    FramesDataset(**dataset_options).prepare()
|
|
|
|
|
|
|
|