bluestarburst's picture
Upload folder using huggingface_hub
09bf9a3
raw
history blame
8.06 kB
from typing import Callable, List, Optional, Union
from torch.utils.data import Dataset
import decord
decord.bridge.set_bridge('torch')
from einops import rearrange
import random
import os
import json
from PIL import Image, ImageFilter
import numpy as np
import cv2
from scipy import ndimage
import tempfile
import ffmpeg
from transformers import CLIPTokenizer
class FramesDataset(Dataset):
def __init__(
self,
samples_dir: str,
prompt_map_path: Union[str, list[str]],
width: int = 512,
height: int = 512,
video_length: int = 16,
sample_start_index: int = 0,
sample_count: int = 1,
sample_frame_rate: int = 8,
variance_threshold: int = 50,
tokenizer: CLIPTokenizer = None,
):
print("FramesDataset", "init", width, height, video_length, sample_count)
self.width = width
self.height = height
self.video_length = video_length
self.sample_count = sample_count
self.tokenizer = tokenizer
self.samples_dir = samples_dir
self.sample_start_index = sample_start_index
self.sample_frame_rate = sample_frame_rate
self.variance_threshold = variance_threshold
self.samples = []
self.prompt_map = None
with open(prompt_map_path, 'r') as f:
self.prompt_map = json.loads(f.read())
self.frames_path = [str(k) for k in self.prompt_map.keys()]
print("FramesDataset", "init", "frames_path", len(self.frames_path))
def load(self):
print("FramesDataset", "load", "samples_dir", self.samples_dir)
def extract_integer(filename):
return int(filename.split('.')[0])
self.samples = []
files = sorted(os.listdir(self.samples_dir), key=extract_integer)
for filename in files:
if 'json' in filename:
with open(f"{self.samples_dir}/{filename}", 'r') as f:
sample = json.loads(f.read())
sample['prompt_ids'] = self.tokenize(sample['prompt'])
self.samples.append(sample)
print("FramesDataset", "load", "samples", len(self.samples))
def tokenize(self, prompt):
input_ids = self.tokenizer(
prompt,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids[0]
return input_ids
def prepare(self):
print("FramesDataset", "prepare")
candidates = []
for dir_path in self.frames_path:
candidates = candidates + self.load_key_frames(dir_path)
print("FramesDataset", "prepare", "candidates", len(candidates))
self.pick(self.sample_count, candidates)
def pick(self, count, candidates):
print("FramesDataset", "pick", count, len(candidates))
sample_index = self.sample_start_index
while True:
key_frame = random.choice(candidates)
print("FramesDataset", "pick", "key_frame", key_frame)
dir_name = os.path.dirname(key_frame)
file_name = os.path.basename(key_frame)
frame_number = int(file_name.split(".")[0])
sample = []
for i in range(frame_number, frame_number + self.video_length):
frame_path = f"{dir_name}/{i}.png"
frame = Image.open(frame_path)
frame = frame.resize((self.width, self.height))
sample.append(np.array(frame))
sample = np.array(sample)
print("FramesDataset", "pick", "reading sample", sample.shape)
if not self.check(sample):
print("FramesDataset", "pick", "skip")
continue
print("FramesDataset", "pick", "checked")
prompt = self.get_prompt(key_frame)
sample_file = f"{self.samples_dir}/{sample_index}.mp4"
self.write_video(sample, sample_file, self.sample_frame_rate)
print("FramesDataset", "pick", "sample_file", sample_file)
meta_file = f"{self.samples_dir}/{sample_index}.json"
with open(meta_file, 'w') as f:
f.write(json.dumps({
'key_frame': key_frame,
'video_file': sample_file,
'prompt': prompt,
}))
print("FramesDataset", "pick", "meta_file", meta_file)
sample_index = sample_index + 1
if sample_index == self.sample_start_index + self.sample_count:
print("FramesDataset", "pick", "done")
break
def write_video(self, frames, video_file, video_fps):
with tempfile.TemporaryDirectory() as frames_dir:
for index, frame in enumerate(frames):
Image.fromarray(frame).save(f"{frames_dir}/{index}.png")
(ffmpeg
.input(f"{frames_dir}/%d.png")
.output(video_file, vcodec='libx264', vf=f"fps={video_fps}")
.overwrite_output()
.run())
def get_prompt(self, key_frame):
print("FramesDataset", "get_prompt", key_frame)
dir_name = os.path.dirname(key_frame)
file_name = os.path.basename(key_frame)
number = int(file_name.split(".")[0])
prompt = ""
if dir_name in self.prompt_map:
prompt_map = self.prompt_map[dir_name]
for k in prompt_map:
if number >= int(k):
print("FramesDataset", "get_prompt", k, prompt_map[k])
return prompt_map[k]
print("FramesDataset", "get_prompt", "not found")
return prompt
def check(self, sample):
diffs = []
for i in range(0, len(sample)-1):
diffs.append(np.sum(self.blur(sample[i]) - self.blur(sample[i-1])))
first_diff = diffs[0]
variance = np.var(diffs)**(1/2)/first_diff * 100
threshold = self.variance_threshold
return variance < threshold
def blur(self, frame):
image = Image.fromarray(frame)
image = image.filter(ImageFilter.GaussianBlur(radius=5))
return np.array(image)
def load_key_frames(self, dir_path):
print("FramesDataset", "load_key_frames", dir_path)
if not os.path.isdir(dir_path):
raise Exception("Dir not exist")
def extract_integer(filename):
return int(filename.split('.')[0])
candidates = []
files = sorted(os.listdir(dir_path), key=extract_integer)
print("FramesDataset", "load_key_frames", "files", len(files))
count = len(files)
for index, file_name in enumerate(files):
file_path = f"{dir_path}/{file_name}"
if 'png' in file_name and index + self.video_length <= count:
candidates.append(file_path)
print("FramesDataset", "load_key_frames", "candidates", len(candidates))
return candidates
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
meta = self.samples[index]
vr = decord.VideoReader(meta['video_file'])
sample_index = list(range(0, len(vr)))[:self.video_length]
video = vr.get_batch(sample_index)
video = rearrange(video, "f h w c -> f c h w")
meta['pixel_values'] = (video / 127.5 - 1.0)
return meta
if __name__ == "__main__":
tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
dataset = FramesDataset(
samples_dir = "test/FramesDataset/samples_dir",
prompt_map_path = 'test/FramesDataset/prompt_map.json',
width = 512,
height = 512,
video_length = 16,
sample_count = 1,
tokenizer = tokenizer,
variance_threshold = 40,
)
dataset.prepare()
#dataset.load()
#print(len(dataset), dataset[0]['key_frame'], dataset[0]['prompt'])