# diffsynth/datasets/mvdataset.py
# (Hugging Face upload header: uploaded by PencilHu via huggingface_hub, commit 1146a67)
import imageio, os, torch, warnings, torchvision, argparse, json
from peft import LoraConfig, inject_adapter_in_model
from PIL import Image
import pandas as pd
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
import random
from decord import VideoReader
from decord import cpu, gpu
import imageio.v3 as iio
from torchvision import transforms
import torchvision
import random
import decord
from torchvision import transforms
import re
decord.bridge.set_bridge('torch')
import random
import numpy as np
from PIL import Image, ImageOps
class MulltiShot_MultiView_Dataset(torch.utils.data.Dataset):
    """Multi-shot / multi-view video dataset.

    Reads a JSON metadata file mapping video keys to face-detection results
    (`facedetect_v1`) and captions (`text`), keeps only `.mp4` entries that
    contain at least one group of >=3 large, high-confidence, angle-diverse
    face crops, and splits the surviving samples 95/5 into train/test with a
    fixed seed so the split is reproducible across runs.
    """
    def __init__(self, dataset_base_path='/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/datasets/merged_mark_paishe_ds_meg_merge_dwposefilter_paishe.json',
                 ref_image_path='/root/paddlejob/workspace/qizipeng/code/longvideogen/output.json',
                 time_division_factor=4,
                 time_division_remainder=1,
                 max_pixels=1920*1080,
                 height_division_factor=16, width_division_factor=16,
                 transform=None,
                 length=None,
                 resolution=None,
                 prev_length=5,
                 ref_num = 3,
                 training = True):
        """Build the sample index from `dataset_base_path`.

        Args:
            dataset_base_path: path to the metadata JSON.
            ref_image_path: unused here; kept for interface compatibility.
            max_pixels: pixel budget used when resolution is dynamic.
            length: number of frames per sample (stored as `num_frames`).
            resolution: (height, width) tuple, or None for dynamic sizing.
            ref_num: number of reference face crops returned per sample.
            training: selects the train or test split for __getitem__/__len__.

        Raises:
            ValueError: if exactly one of height/width is None.
        """
        self.data_path = dataset_base_path
        self.data = []
        self.length = length
        self.resolution = resolution
        # Bug fix: the declared default `resolution=None` used to crash on
        # tuple unpacking; treat None as "no fixed size" (dynamic resolution).
        self.height, self.width = resolution if resolution is not None else (None, None)
        self.num_frames = length
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.max_pixels = max_pixels
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.prev_length = prev_length
        self.training = training
        self.ref_num = ref_num
        with open(self.data_path, 'r') as f:
            meta_datas = json.load(f)
        for video_path in tqdm(meta_datas.keys()):
            context = meta_datas[video_path]
            candidate_labels = list(context.keys())
            candidate_labels.remove('text')
            disk_path = meta_datas[video_path]["disk_path"]
            if not disk_path.lower().endswith(".mp4"):
                continue
            # Counting frames with imageio here is too slow for indexing;
            # `get_ref_id` currently ignores `total_frame`, so pass None.
            total_frame = None
            ref_id = self.get_ref_id(face_crop_angle = context['facedetect_v1'], facedetect_v1_frame_index = context['facedetect_v1_frame_index'], total_frame = total_frame)
            if ref_id == []:
                continue
            ref_id_all = []
            for ids in ref_id:
                ref_id_grop = []
                for id in ids:
                    coordinate = context['facedetect_v1'][id][0]['detect']
                    # Keep only confident detections...
                    if context['facedetect_v1'][id][0]['detect']["prob"] < 0.99:
                        continue
                    top, height, width, left = coordinate['top'], coordinate['height'], coordinate['width'], coordinate['left']
                    # ...whose face box is larger than 80px on its short side.
                    if not(min(height, width) > 80 ):
                        continue
                    # NOTE(review): the original comment said "enlarge bbox 1.5x"
                    # but the factor is 1 (no enlargement) — confirm intent.
                    width = int(width * 1)
                    height = int(height * 1)
                    frame_index = context['facedetect_v1_frame_index'][id]
                    ref_id_grop.append([top, height, width, left, int(frame_index)])
                if ref_id_grop != []:
                    # Require >=3 usable crops per group so data stays
                    # consistent with the default ref_num = 3.
                    if len(ref_id_grop) >= 3:
                        ref_id_all.append(ref_id_grop)
            if ref_id_all == []:
                continue
            meta_prompt = {}
            meta_prompt["global_caption"] = None
            meta_prompt["per_shot_prompt"] = []
            meta_prompt["single_prompt"] = context['text']
            self.data.append({'video_path': disk_path, 'meta_prompt': meta_prompt, "ref_id_all": ref_id_all})
        # Deterministic 95/5 train/test split (fixed seed keeps it stable).
        random.seed(42)
        total = len(self.data)
        test_count = max(1, int(total * 0.05))  # at least one test sample
        test_indices = set(random.sample(range(total), test_count))
        self.data_test = [self.data[i] for i in range(total) if i in test_indices]
        self.data_train = [self.data[i] for i in range(total) if i not in test_indices]
        print(f"🔥 数据集划分完成:Train={len(self.data_train)}, Test={len(self.data_test)}")
        if self.height is not None and self.width is not None:
            print("Height and width are fixed. Setting `dynamic_resolution` to False.")
            self.dynamic_resolution = False
        elif self.height is None and self.width is None:
            print("Height and width are none. Setting `dynamic_resolution` to True.")
            self.dynamic_resolution = True
        else:
            # Bug fix: previously `dynamic_resolution` was left unset when
            # exactly one of height/width was None, causing a late
            # AttributeError in get_height_width(); fail fast instead.
            raise ValueError("`resolution` must specify both height and width, or be None.")
def get_ref_id(self, face_crop_angle, facedetect_v1_frame_index = None, total_frame = None, angle_threshold=50):
"""
返回满足角度差异要求的三元组 [i, j, k]
要求:
- face_crop_angle[i] / [j] / [k] 都必须非空
- i,j 两者任意 yaw/pitch/roll 差值 > angle_threshold
- k != i != j,且 k 也必须非空
"""
ref_id = []
max_try = 5
need_max = 3
try_num = 0
# 过滤空元素,保留有效索引
valid_indices = [idx for idx, item in enumerate(face_crop_angle) if item]
N = len(valid_indices)
if N < 3:
return ref_id # 不足 3 张有效图,无法组成三元组
# 两两组合检查角度差
for a in range(N - 1):
i = valid_indices[a]
# if facedetect_v1_frame_index[i] > total_frame:
# continue
angle_i = face_crop_angle[i][0]["angle"]
for b in range(a + 1, N):
j = valid_indices[b]
# if facedetect_v1_frame_index[j] > total_frame:
# continue
angle_j = face_crop_angle[j][0]["angle"]
# 判断是否满足阈值
if (
abs(angle_i["yaw"] - angle_j["yaw"]) > angle_threshold or
abs(angle_i["pitch"] - angle_j["pitch"]) > angle_threshold or
abs(angle_i["roll"] - angle_j["roll"]) > angle_threshold
):
# 找第三个 k
for c in range(N):
k = valid_indices[c]
# if facedetect_v1_frame_index[k] > total_frame:
# continue
if k != i and k != j:
ref_id.append([i, j, k])
break
try_num += 1
if try_num >= max_try or len(ref_id) >= need_max:
return ref_id
return ref_id
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def get_height_width(self, image):
if self.dynamic_resolution:
width, height = image.size
if width * height > self.max_pixels:
scale = (width * height / self.max_pixels) ** 0.5
height, width = int(height / scale), int(width / scale)
height = height // self.height_division_factor * self.height_division_factor
width = width // self.width_division_factor * self.width_division_factor
else:
height, width = self.height, self.width
return height, width
# def
# img_ratio = img.width / img.height
# target_ratio = w / h
# if img_ratio > target_ratio: # Image is wider than target
# new_width = w
# new_height = int(new_width / img_ratio)
# else: # Image is taller than target
# new_height = h
# new_width = int(new_height * img_ratio)
# # img = img.resize((new_width, new_height), Image.ANTIALIAS)
# img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# # Create a new image with the target size and place the resized image in the center
# delta_w = w - img.size[0]
# delta_h = h - img.size[1]
# padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
# new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
def resize_ref(self, img, target_h, target_w):
h = target_h
w = target_w
img = img.convert("RGB")
# Calculate the required size to keep aspect ratio and fill the rest with padding.
img_ratio = img.width / img.height
target_ratio = w / h
if img_ratio > target_ratio: # Image is wider than target
new_width = w
new_height = int(new_width / img_ratio)
else: # Image is taller than target
new_height = h
new_width = int(new_height * img_ratio)
# img = img.resize((new_width, new_height), Image.ANTIALIAS)
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Create a new image with the target size and place the resized image in the center
delta_w = w - img.size[0]
delta_h = h - img.size[1]
padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
return new_img
def load_video_crop_ref_image(self, video_path=None, ref_id_all=[[]]):
### fps 转化
reader = imageio.get_reader(video_path)
meta = reader.get_meta_data()
original_fps = meta.get("fps", 24)
target_fps = 16
duration_seconds = 5
target_frames = target_fps * duration_seconds + 1 # = 80 frames
# ---- 获取原视频帧数 ----
try:
total_original_frames = reader.count_frames()
except:
total_original_frames = int(meta.get("duration", 5) * original_fps)
# ---- 需要多少原始帧(5秒)----
need_orig_frames = int(original_fps * duration_seconds)
# ---- Case 1: 原视频 >= 5秒 → 随机选择 5 秒起点 ----
if total_original_frames > need_orig_frames:
max_start = total_original_frames - need_orig_frames
start_frame = random.randint(0, max_start)
segment_start = start_frame
segment_end = start_frame + need_orig_frames
else:
# ---- Case 2: 原视频 < 5秒 → 用全部帧 ----
segment_start = 0
segment_end = total_original_frames
# ---- 均匀采样 80 帧 ----
sample_ids = np.linspace(segment_start, segment_end - 1, num=target_frames, dtype=int)
frames = []
for frame_id in sample_ids:
frame = reader.get_data(int(frame_id))
frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame, *self.get_height_width(frame))
frames.append(frame)
# ===========================
# 选择参考图部分(你要求的)
# ===========================
# 1)从 ref_images_all(三维 list)里随机选一组
# ref_images_all = [ [img1, img2, img3], [imgA, imgB, imgC], ... ]
ref_group = random.choice(ref_id_all)
# 2)检查资源是否足够
if len(ref_group) < self.ref_num:
raise ValueError(f"需要 {self.ref_num} 张参考图,但该组只有 {len(ref_group)} 张。")
# 3)从该组中随机选 self.ref_num 张
selected_refs = random.sample(ref_group, self.ref_num)
random.shuffle(selected_refs)
ref_images = []
for sf in selected_refs:
top, height, width, left, frame_index = sf
# import pdb; pdb.set_trace()
if frame_index > total_original_frames:
print(f"{video_path}, frame_index({frame_index}) out of range")
frame = reader.get_data(int(frame_index))
frame = Image.fromarray(frame)
xmin, ymin, xmax, ymax = left, top, left + width, top + height
cropped_image = frame.crop((xmin, ymin, xmax, ymax)).convert("RGB")
cropped_image = self.resize_ref(cropped_image, self.height, self.width)
# Calculate the required size to keep aspect ratio and fill the rest with padding.
ref_images.append(cropped_image)
reader.close()
return frames, ref_images
def __getitem__(self, index):
max_retry = 10 # 最多重试 10 次,避免死循环
retry = 0
while retry < max_retry:
# ----- 选择 train / test 数据 -----
if self.training:
meta_data = self.data_train[index % len(self.data_train)]
else:
meta_data = self.data_test[index % len(self.data_test)]
video_path = meta_data['video_path']
meta_prompt = meta_data['meta_prompt']
ref_id_all = meta_data['ref_id_all']
# ----- 尝试读取 video + ref -----
try:
input_video, ref_images = self.load_video_crop_ref_image(
video_path=video_path,
ref_id_all=ref_id_all
)
except Exception as e:
print("❌ Exception in load_video_crop_ref_image")
print(f" video_path: {video_path}")
print(f" error type: {type(e).__name__}")
print(f" error msg : {e}")
# 打印 traceback,定位问题更容易
import traceback
traceback.print_exc()
input_video = None
ref_images = None
# ----- 如果成功,并且 video 不为空,返回结果 -----
if input_video is not None and len(input_video) > 0:
return {
"global_caption": None,
"shot_num": 1,
"pre_shot_caption": [],
"single_caption": meta_prompt["single_prompt"],
"video": input_video,
"ref_num": self.ref_num,
"ref_images": ref_images,
"video_path": video_path
}
# ----- 如果失败,换 index,并继续尝试 -----
retry += 1
index = random.randint(0, len(self.data_train) - 1 if self.training else len(self.data_test) - 1)
# 若 10 次都失败,返回最后一次的错误内容
raise RuntimeError(f"❌ [Dataset] Failed to load video/ref after {max_retry} retries.")
def __len__(self):
if self.training:
return len(self.data_train)
else:
return len(self.data_test)
if __name__ == '__main__':
    # Smoke test: build the dataset from the default metadata path and read
    # samples sequentially. Requires the hard-coded dataset files to exist.
    from torch.utils.data import DataLoader
    dataset = MulltiShot_MultiView_Dataset(length=49, resolution=(384, 640), training=True)
    print(len(dataset))
    metadata = dataset[0]
    # results = dataset[0]
    # loader = DataLoader(
    #     dataset,
    #     batch_size=1,        # videos are usually loaded with batch=1
    #     shuffle=False,       # set True to shuffle
    #     num_workers=10,      # key: 10 worker subprocesses load in parallel
    #     pin_memory=True,
    #     prefetch_factor=2,   # each worker prefetches 2 samples
    #     collate_fn=lambda x: x[0],  # no collation at all
    # )
    # for batch in tqdm(loader):
    #     pass
    # Iterate every sample once to surface decode/crop errors.
    for i in tqdm(range(len(dataset))):
        file = dataset[i]
    # NOTE(review): deliberate abort after the sweep (debugging aid) —
    # presumably meant to stop before any further processing; confirm.
    assert 0