| import imageio, os, torch, warnings, torchvision, argparse, json |
| from peft import LoraConfig, inject_adapter_in_model |
| from PIL import Image |
| import pandas as pd |
| from tqdm import tqdm |
| from accelerate import Accelerator |
| from accelerate.utils import DistributedDataParallelKwargs |
| import random |
| from decord import VideoReader |
| from decord import cpu, gpu |
| import imageio.v3 as iio |
|
|
| from torchvision import transforms |
| import torchvision |
| import random |
| import decord |
| from torchvision import transforms |
| import re |
| decord.bridge.set_bridge('torch') |
| import random |
| import numpy as np |
| from PIL import Image, ImageOps |
|
|
class MulltiShot_MultiView_Dataset(torch.utils.data.Dataset):
    """Multi-shot / multi-view face video dataset.

    Reads a metadata JSON mapping video ids to face-detection contexts,
    keeps only .mp4 clips that contain at least one group of >= 3 confident,
    large-enough face crops with sufficiently different head poses, and
    splits the result ~95/5 into train/test (seeded, reproducible).
    `__getitem__` returns a dict with a fixed-length frame sequence plus
    `ref_num` cropped reference face images.
    """

    def __init__(self, dataset_base_path='/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/datasets/merged_mark_paishe_ds_meg_merge_dwposefilter_paishe.json',
                 ref_image_path='/root/paddlejob/workspace/qizipeng/code/longvideogen/output.json',
                 time_division_factor=4,
                 time_division_remainder=1,
                 max_pixels=1920*1080,
                 height_division_factor=16, width_division_factor=16,
                 transform=None,
                 length=None,
                 resolution=None,
                 prev_length=5,
                 ref_num = 3,
                 training = True):
        """
        Args:
            dataset_base_path: path to the metadata JSON ({video_id: context}).
            ref_image_path: accepted for interface compatibility; not used here.
            time_division_factor / time_division_remainder: temporal grid
                parameters (stored, not used in this file).
            max_pixels: cap on H*W when resolution is dynamic.
            height_division_factor / width_division_factor: output sizes are
                floored to multiples of these.
            transform: accepted for interface compatibility; not used here.
            length: number of frames per sample (stored as ``num_frames``).
            resolution: (height, width) tuple, or None for dynamic resolution.
            prev_length: stored, not used in this file.
            ref_num: number of reference face crops returned per sample.
            training: select the train split (True) or the test split (False).

        Raises:
            ValueError: if `resolution` sets exactly one of height/width.
        """
        self.data_path = dataset_base_path
        self.data = []
        self.length = length
        self.resolution = resolution
        # `resolution` may legitimately be None (dynamic resolution mode);
        # the previous unconditional unpack raised TypeError in that case.
        if resolution is not None:
            self.height, self.width = resolution
        else:
            self.height, self.width = None, None
        self.num_frames = length
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.max_pixels = max_pixels
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.prev_length = prev_length
        self.training = training
        self.ref_num = ref_num

        with open(self.data_path, 'r') as f:
            meta_datas = json.load(f)

        for video_path in tqdm(meta_datas.keys()):
            context = meta_datas[video_path]

            disk_path = context["disk_path"]
            if not disk_path.lower().endswith(".mp4"):
                continue

            total_frame = None
            ref_id = self.get_ref_id(face_crop_angle=context['facedetect_v1'],
                                     facedetect_v1_frame_index=context['facedetect_v1_frame_index'],
                                     total_frame=total_frame)
            if not ref_id:
                continue

            ref_id_all = []
            for ids in ref_id:
                ref_id_group = []
                for face_idx in ids:
                    detect = context['facedetect_v1'][face_idx][0]['detect']
                    # Keep only confident detections.
                    if detect["prob"] < 0.99:
                        continue
                    top, height, width, left = detect['top'], detect['height'], detect['width'], detect['left']
                    # Keep only crops larger than 80px on the short side
                    # (written via `not >` so NaN-valued sizes are rejected too).
                    if not (min(height, width) > 80):
                        continue
                    width = int(width)
                    height = int(height)
                    frame_index = context['facedetect_v1_frame_index'][face_idx]
                    ref_id_group.append([top, height, width, left, int(frame_index)])
                # A usable group needs at least 3 surviving crops.
                if len(ref_id_group) >= 3:
                    ref_id_all.append(ref_id_group)
            if not ref_id_all:
                continue

            meta_prompt = {
                "global_caption": None,
                "per_shot_prompt": [],
                "single_prompt": context['text'],
            }
            self.data.append({'video_path': disk_path, 'meta_prompt': meta_prompt, "ref_id_all": ref_id_all})

        # NOTE: seeds the *global* random module so the split is reproducible
        # across runs; this also resets any RNG state set by the caller.
        random.seed(42)
        total = len(self.data)
        test_count = max(1, int(total * 0.05))
        test_indices = set(random.sample(range(total), test_count))

        self.data_test = [self.data[i] for i in range(total) if i in test_indices]
        self.data_train = [self.data[i] for i in range(total) if i not in test_indices]
        print(f"🔥 数据集划分完成:Train={len(self.data_train)}, Test={len(self.data_test)}")

        if self.height is not None and self.width is not None:
            print("Height and width are fixed. Setting `dynamic_resolution` to False.")
            self.dynamic_resolution = False
        elif self.height is None and self.width is None:
            print("Height and width are none. Setting `dynamic_resolution` to True.")
            self.dynamic_resolution = True
        else:
            # Previously this case fell through silently, leaving
            # `dynamic_resolution` unset and crashing later with AttributeError.
            raise ValueError("`resolution` must specify both height and width, or be None.")

    def get_ref_id(self, face_crop_angle, facedetect_v1_frame_index = None, total_frame = None, angle_threshold=50):
        """Return up to 3 index triplets [i, j, k] of face detections.

        Requirements:
          - face_crop_angle[i] / [j] / [k] must all be non-empty;
          - i and j differ by more than `angle_threshold` degrees in at
            least one of yaw / pitch / roll;
          - k is any other non-empty index (k != i and k != j).

        `facedetect_v1_frame_index` and `total_frame` are accepted for
        interface compatibility but not used here.
        """
        ref_id = []
        max_try = 5      # stop after scanning this many candidate (i, j) pairs
        need_max = 3     # return at most this many triplets
        try_num = 0

        # Indices whose detection list is non-empty.
        valid_indices = [idx for idx, item in enumerate(face_crop_angle) if item]
        N = len(valid_indices)

        if N < 3:
            return ref_id

        for a in range(N - 1):
            i = valid_indices[a]
            angle_i = face_crop_angle[i][0]["angle"]

            for b in range(a + 1, N):
                j = valid_indices[b]
                angle_j = face_crop_angle[j][0]["angle"]

                if (
                    abs(angle_i["yaw"] - angle_j["yaw"]) > angle_threshold or
                    abs(angle_i["pitch"] - angle_j["pitch"]) > angle_threshold or
                    abs(angle_i["roll"] - angle_j["roll"]) > angle_threshold
                ):
                    # Take the first remaining valid index as the third view.
                    for c in range(N):
                        k = valid_indices[c]
                        if k != i and k != j:
                            ref_id.append([i, j, k])
                            break

                try_num += 1
                if try_num >= max_try or len(ref_id) >= need_max:
                    return ref_id

        return ref_id

    def crop_and_resize(self, image, target_height, target_width):
        """Scale `image` up to cover (target_height, target_width), then center-crop."""
        width, height = image.size
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
        return image

    def get_height_width(self, image):
        """Return the (height, width) a frame should be resized to.

        Dynamic mode: keep the image's own size, downscaled so that
        H*W <= max_pixels, then floored to the division factors.
        Fixed mode: return the configured (height, width).
        """
        if self.dynamic_resolution:
            width, height = image.size
            if width * height > self.max_pixels:
                scale = (width * height / self.max_pixels) ** 0.5
                height, width = int(height / scale), int(width / scale)
            height = height // self.height_division_factor * self.height_division_factor
            width = width // self.width_division_factor * self.width_division_factor
        else:
            height, width = self.height, self.width
        return height, width

    def resize_ref(self, img, target_h, target_w):
        """Letterbox `img` into (target_h, target_w): aspect-preserving resize
        plus symmetric white padding."""
        h = target_h
        w = target_w
        img = img.convert("RGB")

        img_ratio = img.width / img.height
        target_ratio = w / h

        if img_ratio > target_ratio:
            # Wider than the target: fit width, pad top/bottom.
            new_width = w
            new_height = int(new_width / img_ratio)
        else:
            # Taller than the target: fit height, pad left/right.
            new_height = h
            new_width = int(new_height * img_ratio)

        img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

        delta_w = w - img.size[0]
        delta_h = h - img.size[1]
        padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
        new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))

        return new_img

    def load_video_crop_ref_image(self, video_path=None, ref_id_all=None):
        """Sample a 5-second, 16-fps clip (81 frames) plus reference face crops.

        Args:
            video_path: path to an .mp4 file readable by imageio.
            ref_id_all: list of groups of [top, height, width, left, frame_index]
                face boxes (defaults to one empty group).

        Returns:
            (frames, ref_images): list of PIL frames resized/cropped to the
            target resolution, and `ref_num` letterboxed face crops.

        Raises:
            ValueError: if the chosen group has fewer than `ref_num` crops.
        """
        # Avoid the mutable-default-argument pitfall of the original `=[[]]`.
        if ref_id_all is None:
            ref_id_all = [[]]

        reader = imageio.get_reader(video_path)
        try:
            meta = reader.get_meta_data()
            original_fps = meta.get("fps", 24)
            target_fps = 16
            duration_seconds = 5
            target_frames = target_fps * duration_seconds + 1  # 81 frames

            # Some containers cannot count frames; fall back to duration * fps.
            try:
                total_original_frames = reader.count_frames()
            except Exception:
                total_original_frames = int(meta.get("duration", 5) * original_fps)

            need_orig_frames = int(original_fps * duration_seconds)

            # Pick a random 5-second window (or the whole clip if shorter).
            if total_original_frames > need_orig_frames:
                max_start = total_original_frames - need_orig_frames
                start_frame = random.randint(0, max_start)
                segment_start = start_frame
                segment_end = start_frame + need_orig_frames
            else:
                segment_start = 0
                segment_end = total_original_frames

            # Uniformly sample `target_frames` indices within the window.
            sample_ids = np.linspace(segment_start, segment_end - 1, num=target_frames, dtype=int)

            frames = []
            for frame_id in sample_ids:
                frame = reader.get_data(int(frame_id))
                frame = Image.fromarray(frame)
                frame = self.crop_and_resize(frame, *self.get_height_width(frame))
                frames.append(frame)

            ref_group = random.choice(ref_id_all)

            if len(ref_group) < self.ref_num:
                raise ValueError(f"需要 {self.ref_num} 张参考图,但该组只有 {len(ref_group)} 张。")

            # Random subset, shuffled so reference order carries no signal.
            selected_refs = random.sample(ref_group, self.ref_num)
            random.shuffle(selected_refs)

            ref_images = []
            for top, height, width, left, frame_index in selected_refs:
                # Frames are 0-based, so index == total is already out of range
                # (the original used `>` here). get_data may still raise; the
                # caller's retry loop handles that.
                if frame_index >= total_original_frames:
                    print(f"{video_path}, frame_index({frame_index}) out of range")
                frame = Image.fromarray(reader.get_data(int(frame_index)))
                xmin, ymin, xmax, ymax = left, top, left + width, top + height
                cropped_image = frame.crop((xmin, ymin, xmax, ymax)).convert("RGB")
                cropped_image = self.resize_ref(cropped_image, self.height, self.width)
                ref_images.append(cropped_image)
        finally:
            # Always release the reader, even when decoding fails mid-way
            # (the original leaked it on any exception).
            reader.close()

        return frames, ref_images

    def __getitem__(self, index):
        """Return one sample dict, retrying broken videos with fresh indices.

        Raises:
            RuntimeError: if no sample could be loaded after `max_retry` tries.
        """
        max_retry = 10
        retry = 0

        while retry < max_retry:
            if self.training:
                meta_data = self.data_train[index % len(self.data_train)]
            else:
                meta_data = self.data_test[index % len(self.data_test)]

            video_path = meta_data['video_path']
            meta_prompt = meta_data['meta_prompt']
            ref_id_all = meta_data['ref_id_all']

            try:
                input_video, ref_images = self.load_video_crop_ref_image(
                    video_path=video_path,
                    ref_id_all=ref_id_all
                )
            except Exception as e:
                print("❌ Exception in load_video_crop_ref_image")
                print(f" video_path: {video_path}")
                print(f" error type: {type(e).__name__}")
                print(f" error msg : {e}")

                import traceback
                traceback.print_exc()
                input_video = None
                ref_images = None

            if input_video is not None and len(input_video) > 0:
                return {
                    "global_caption": None,
                    "shot_num": 1,
                    "pre_shot_caption": [],
                    "single_caption": meta_prompt["single_prompt"],
                    "video": input_video,
                    "ref_num": self.ref_num,
                    "ref_images": ref_images,
                    "video_path": video_path
                }

            # Broken sample: retry with a different random index.
            retry += 1
            index = random.randint(0, len(self.data_train) - 1 if self.training else len(self.data_test) - 1)

        raise RuntimeError(f"❌ [Dataset] Failed to load video/ref after {max_retry} retries.")

    def __len__(self):
        """Size of the active split (train when `training`, else test)."""
        if self.training:
            return len(self.data_train)
        else:
            return len(self.data_test)
|
|
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    # Quick smoke test: build the train split with a fixed clip length and
    # resolution, fetch one sample, then walk the whole dataset once.
    ds = MulltiShot_MultiView_Dataset(length=49, resolution=(384, 640), training=True)
    print(len(ds))
    sample = ds[0]

    for sample_idx in tqdm(range(len(ds))):
        _ = ds[sample_idx]

    # Deliberate debug stop after the full pass.
    assert 0