| import imageio, os, torch, warnings, torchvision, argparse, json |
| from peft import LoraConfig, inject_adapter_in_model |
| from PIL import Image |
| import pandas as pd |
| from tqdm import tqdm |
| from accelerate import Accelerator |
| from accelerate.utils import DataLoaderConfiguration, set_seed |
| import wandb |
| import tempfile |
| import requests |
| import random |
| import decord |
| import cv2 |
| import numpy as np |
| import shutil |
| import imageio.v3 as iio |
| import imageio_ffmpeg as ffmpeg |
| import math |
| import re |
| from diffsynth.trainers.timer import get_timers |
| import time |
| import glob |
| from safetensors.torch import save_file, load_file |
| import math |
| import random |
| from typing import List, Dict, Iterable, Iterator, Sequence, Optional |
| import torch |
| from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler, DistributedSampler |
| import torch.distributed as dist |
| from tqdm import tqdm |
| from multiprocessing import Pool, cpu_count |
| import math |
| import random |
| from collections import defaultdict, deque |
| from typing import Iterable, List, Sequence, Optional, Callable, Dict |
|
|
| import torch |
| from torch.utils.data import DataLoader, Sampler, BatchSampler, RandomSampler, DistributedSampler |
| from accelerate.utils import DataLoaderConfiguration, DeepSpeedPlugin |
| import random |
| import matplotlib.pyplot as plt |
|
|
class DiffusionTrainingModule(torch.nn.Module):
    """Base ``nn.Module`` for diffusion training wrappers.

    Provides helpers to move child modules between devices/dtypes, list the
    trainable parameters, inject PEFT LoRA adapters, and export only the
    trainable subset of a state dict.
    """

    def __init__(self):
        super().__init__()

    def to(self, *args, **kwargs):
        """Forward ``.to(...)`` to every direct child module and return ``self``.

        Only children are moved; parameters/buffers registered directly on this
        module itself are not touched.
        """
        for _, child in self.named_children():
            child.to(*args, **kwargs)
        return self

    def trainable_modules(self):
        """Return a lazy iterator over parameters with ``requires_grad=True``.

        Despite the name, this yields parameters (e.g. for an optimizer), not modules.
        """
        return (param for param in self.parameters() if param.requires_grad)

    def trainable_param_names(self):
        """Return the set of qualified names of all parameters with ``requires_grad=True``."""
        return {name for name, param in self.named_parameters() if param.requires_grad}

    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
        """Inject PEFT LoRA adapters into ``model`` and return the result.

        Args:
            model: module to receive the adapters.
            target_modules: passed to ``LoraConfig(target_modules=...)``.
            lora_rank: LoRA rank ``r``.
            lora_alpha: LoRA alpha; defaults to ``lora_rank`` when omitted.
        """
        alpha = lora_rank if lora_alpha is None else lora_alpha
        config = LoraConfig(r=lora_rank, lora_alpha=alpha, target_modules=target_modules)
        return inject_adapter_in_model(config, model)

    def export_trainable_state_dict(self, state_dict, remove_prefix=None):
        """Filter ``state_dict`` down to entries whose key is a trainable parameter name.

        Args:
            state_dict: mapping of parameter name -> value.
            remove_prefix: if given, stripped from the start of every kept key
                that has it; keys without the prefix are kept unchanged.

        Returns:
            A new dict with only the trainable entries (original key order kept).
        """
        keep = self.trainable_param_names()
        exported = {}
        for key, value in state_dict.items():
            if key not in keep:
                continue
            if remove_prefix is not None and key.startswith(remove_prefix):
                key = key[len(remove_prefix):]
            exported[key] = value
        return exported
|
|
|
|
|
|
class ModelLogger:
    """Logger combining the Accelerator's trackers with per-rank text log files.

    It also owns the global optimizer-step counter and registers itself with the
    Accelerator via ``register_for_checkpointing``, so ``state_dict()`` /
    ``load_state_dict()`` are included in ``accelerator.save_state`` /
    ``accelerator.load_state``.
    """

    # Class-level default so on_step_end works even if set_epoch() was never called.
    current_epoch: int = 0

    def __init__(self, accelerator: "Accelerator", output_path: str):
        self.accelerator = accelerator
        self.output_path = output_path

        # Number of completed optimizer steps; persisted via state_dict()/load_state_dict().
        self.global_step = 0
        self.current_epoch = 0

        self.accelerator.register_for_checkpointing(self)

        self.timers = get_timers()

        # Every rank creates the directory (exist_ok makes this race-safe), so it also
        # works when nodes do not share a filesystem.
        log_dir = os.path.join(self.output_path, "log")
        os.makedirs(log_dir, exist_ok=True)

        log_file_path = os.path.join(log_dir, f"rank_{self.accelerator.process_index}.log")
        self.file_logger = open(log_file_path, "a+", encoding="utf-8")
        self.log_to_file(f"Logger initialized. Current global_step: {self.global_step}")

    def state_dict(self):
        """Return the state saved by ``accelerator.save_state`` (required by register_for_checkpointing)."""
        return {"global_step": self.global_step}

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict` (required by register_for_checkpointing)."""
        self.global_step = state_dict["global_step"]
        self.log_to_file(f"Logger state restored. Resumed global_step: {self.global_step}")

    def log_to_file(self, message: str):
        """Append a timestamped line to this rank's log file and flush it immediately."""
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
        self.file_logger.write(f"[{timestamp}] {message}\n")
        self.file_logger.flush()

    def on_step_end(self, loss: torch.Tensor, optimizer: torch.optim.Optimizer):
        """Record metrics after an optimizer step and advance ``global_step``.

        Must be called on every rank: the loss gather is a collective operation.
        Tracker logging happens on the main process only; every rank writes a
        line to its own log file.

        Args:
            loss: this rank's loss for the step (must support ``.item()``).
            optimizer: read for the learning rate of its first param group.
        """
        # stop() is treated as returning milliseconds.
        step_time_s = self.timers.step_time.stop() / 1000.0

        gathered_losses = self.accelerator.gather_for_metrics(loss.detach())
        # Average in fp32 so low-precision (e.g. bf16) losses do not lose precision.
        avg_loss = gathered_losses.float().mean().item()
        local_loss = loss.item()
        learning_rate = optimizer.param_groups[0]["lr"]

        if self.accelerator.is_main_process:
            log_dict = {
                "train/loss": avg_loss,
                "train/learning_rate": learning_rate,
                "perf/step_time_seconds": step_time_s,
                "progress/epoch": self.current_epoch,
                "progress/step": self.global_step,
            }
            self.accelerator.log(log_dict, step=self.global_step)

        log_message = (
            f"Epoch: {self.current_epoch} | "
            f"Iter: {self.global_step:06d} | "
            f"Step Time: {step_time_s:.3f}s | "
            f"Local Loss: {local_loss:.4f} | "
            f"Avg Loss: {avg_loss:.4f} | "
            f"LR: {learning_rate:.6f}"
        )
        self.log_to_file(log_message)

        # Incremented on every rank so checkpoint decisions stay in sync.
        self.global_step += 1
        self.step_start_time = time.time()

    def set_epoch(self, epoch: int):
        """Set the current epoch (called by the external training loop)."""
        self.current_epoch = epoch

    def close(self):
        """Write a final line and close the log file; safe to call more than once."""
        if self.file_logger.closed:
            return
        self.log_to_file("Logger closing.")
        self.file_logger.close()
|
|
|
|
| class CheckpointManager: |
| """ |
| 使用 Accelerator 管理检查点的保存和加载。 |
| - 只保存可训练的权重以节省空间。 |
| - 自动轮换检查点,只保留最新的 N 个。 |
| """ |
| def __init__(self, accelerator: Accelerator, output_path: str, save_steps: int, save_epoches: int, max_to_keep: int, state_dict_converter=lambda x:x): |
| self.accelerator = accelerator |
| self.checkpoints_dir = os.path.join(output_path) |
| self.save_steps = save_steps |
| self.save_epoches = save_epoches |
| self.max_to_keep = max_to_keep |
| self.state_dict_converter = state_dict_converter |
| |
| if self.accelerator.is_main_process: |
| os.makedirs(self.checkpoints_dir, exist_ok=True) |
|
|
| |
| def get_deepspeed_engine(self, model): |
| try: |
| import deepspeed |
| if isinstance(model, deepspeed.DeepSpeedEngine): |
| return model |
| except Exception: |
| pass |
| st = getattr(self.accelerator, "state", None) |
| if st is not None: |
| if getattr(st, "deepspeed_engine", None) is not None: |
| return st.deepspeed_engine |
| plugin = getattr(st, "deepspeed_plugin", None) |
| if plugin is not None: |
| if getattr(plugin, "engine", None) is not None: |
| return plugin.engine |
| if getattr(plugin, "deepspeed_engine", None) is not None: |
| return plugin.deepspeed_engine |
| return None |
|
|
| def save_checkpoint(self, model: torch.nn.Module, global_step: int, epoch_id: int): |
| """ |
| Zero-2 + Accelerate 场景下的推荐写法: |
| - 所有 rank:accelerator.save_state() 保存 deepspeed + optimizer + scheduler + RNG |
| - rank0 再额外导出 trainable 权重为 weights.safetensors |
| """ |
| if not (global_step > 0 and global_step % self.save_steps == 0): |
| return |
|
|
| checkpoint_name = f"checkpoint-step-{global_step}-epoch-{epoch_id + 1}" |
| save_path = os.path.join(self.checkpoints_dir, checkpoint_name) |
|
|
| |
| if self.accelerator.is_main_process: |
| os.makedirs(save_path, exist_ok=True) |
| self.accelerator.wait_for_everyone() |
|
|
| def save_epoch_checkpoint(self, model: torch.nn.Module, global_step: int, epoch_id: int): |
| if self.save_epoches <= 0: |
| return |
| if (epoch_id + 1) % self.save_epoches != 0: |
| return |
|
|
| checkpoint_name = f"checkpoint-step-{global_step}-epoch-{epoch_id + 1}" |
| save_path = os.path.join(self.checkpoints_dir, checkpoint_name) |
|
|
| if self.accelerator.is_main_process: |
| os.makedirs(save_path, exist_ok=True) |
| self.accelerator.wait_for_everyone() |
|
|
| self.accelerator.print(f"[CKPT] epoch save_state -> {save_path}") |
| self.accelerator.save_state(save_path) |
|
|
| if self.accelerator.is_main_process: |
| with open(os.path.join(save_path, "trainer_state.json"), "w") as f: |
| json.dump({"global_step": global_step}, f) |
| state_dict = self.accelerator.get_state_dict(model) |
| state_dict = self.accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix="pipe.dit.") |
| state_dict = self.state_dict_converter(state_dict) |
|
|
| weights_path = os.path.join(save_path, "weights.safetensors") |
| self.accelerator.save(state_dict, weights_path, safe_serialization=True) |
| self.accelerator.print(f"[CKPT] epoch trainable weights saved -> {weights_path}") |
| self._rotate_checkpoints() |
| self.accelerator.wait_for_everyone() |
| |
| |
| self.accelerator.print(f"[CKPT] accelerator.save_state -> {save_path}") |
| self.accelerator.save_state(save_path) |
|
|
| |
| if self.accelerator.is_main_process: |
| with open(os.path.join(save_path, "trainer_state.json"), "w") as f: |
| json.dump({"global_step": global_step}, f) |
| state_dict = self.accelerator.get_state_dict(model) |
| state_dict = self.accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix="pipe.dit.") |
| state_dict = self.state_dict_converter(state_dict) |
|
|
| weights_path = os.path.join(save_path, "weights.safetensors") |
| self.accelerator.save(state_dict, weights_path, safe_serialization=True) |
| self.accelerator.print(f"[CKPT] trainable weights saved -> {weights_path}") |
| |
| self._rotate_checkpoints() |
| self.accelerator.wait_for_everyone() |
|
|
| def load_checkpoint(self, model, resume_from_checkpoint): |
| if not resume_from_checkpoint: |
| return None |
|
|
| |
| if resume_from_checkpoint == "latest": |
| load_path = self._get_load_path(resume_from_checkpoint) |
| if not load_path: |
| self.accelerator.print("No checkpoint found to resume from. Starting from scratch.") |
| return |
| else: |
| load_path = resume_from_checkpoint |
|
|
| if load_path is None: |
| self.accelerator.print("[Resume] No checkpoint found") |
| return None |
|
|
| self.accelerator.print(f"[Resume] Loading from: {load_path}") |
|
|
| |
| try: |
| self.accelerator.load_state(load_path) |
| self.accelerator.print("[Resume] accelerator.load_state() DONE.") |
| except Exception as e: |
| self.accelerator.print(f"[Resume] ERROR: {e}") |
|
|
| |
| self.accelerator.wait_for_everyone() |
| |
|
|
| self.accelerator.print("[Resume] Model fully restored.") |
| return load_path |
|
|
|
|
| def _get_load_path(self, resume_from_checkpoint: str | bool) -> str | None: |
| """辅助函数,解析 resume_from_checkpoint 并返回有效的加载路径。""" |
| if resume_from_checkpoint is True or str(resume_from_checkpoint).lower() == "latest": |
| return self.find_latest_checkpoint() |
| |
| |
| if os.path.isdir(resume_from_checkpoint): |
| return resume_from_checkpoint |
| return None |
|
|
| def find_latest_checkpoint(self) -> str | None: |
| """在检查点目录中找到最新的检查点。""" |
| all_checkpoints = glob.glob(os.path.join(self.checkpoints_dir, "checkpoint-step-*")) |
| if not all_checkpoints: |
| return None |
| |
| try: |
| latest_checkpoint = max( |
| all_checkpoints, |
| key=lambda path: int(re.search(r'checkpoint-step-(\d+)', path).group(1)) |
| ) |
| return latest_checkpoint |
| except (ValueError, AttributeError): |
| |
| return None |
|
|
| def _rotate_checkpoints(self): |
| """删除旧的检查点,只保留 `max_to_keep` 个。""" |
| if self.max_to_keep <= 0: |
| return |
|
|
| all_checkpoints = glob.glob(os.path.join(self.checkpoints_dir, "checkpoint-step-*")) |
| |
| |
| try: |
| sorted_checkpoints = sorted( |
| all_checkpoints, |
| key=lambda path: int(re.search(r'checkpoint-step-(\d+)-epoch-\d+', path).group(1)) |
| ) |
| except (ValueError, AttributeError): |
| self.accelerator.print("[Warning] Could not sort checkpoints for rotation due to naming issues.") |
| return |
|
|
| |
| num_to_delete = len(sorted_checkpoints) - self.max_to_keep |
| if num_to_delete > 0: |
| checkpoints_to_delete = sorted_checkpoints[:num_to_delete] |
| for ckpt_path in checkpoints_to_delete: |
| self.accelerator.print(f"Deleting old checkpoint: {ckpt_path}") |
| shutil.rmtree(ckpt_path) |
|
|
|
|
|
|
def launch_training_task(
    args,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    num_epochs: int = 1,
    gradient_accumulation_steps: int = 1,
    output_path: str = "./models/train",
    save_steps: int = 10,
    save_epoches: int = 1,
    max_checkpoints_to_keep=5,
    resume_from_checkpoint: str | bool = "latest",
    seed: int = 42,
    visual_log_project_name: str = None,
):
    """Run the Accelerate training loop with checkpointing, logging and optional debug inference.

    Args:
        args: parsed CLI namespace. Reads ``batch_size``, ``visual_log_project_name`` and
            ``seed``; optionally (via getattr) ``num_workers`` and the ``debug_infer*``
            options. It is also forwarded to ``model(data, args)``.
        dataset: map-style dataset; each batch is the plain list of samples (identity collate).
        model: training wrapper; ``model(data, args)`` returns the loss to backpropagate.
        optimizer / scheduler: prepared together with the model by the Accelerator.
        num_epochs: total number of epochs (resumed runs continue up to this count).
        gradient_accumulation_steps: micro-batches per optimizer step.
        output_path: root for checkpoints, logs, loss plots and debug outputs.
        save_steps / save_epoches / max_checkpoints_to_keep: checkpoint policy.
        resume_from_checkpoint: falsy disables resuming; ``True`` or ``"latest"`` resumes
            from the newest checkpoint in ``output_path``; another string is a checkpoint path.
        seed: seed passed to ``set_seed(..., device_specific=True)``.
        visual_log_project_name: project name passed to ``init_trackers``.
    """
    def collate_fn_identity(batch):
        # Keep the batch as a list of samples (presumably variable-sized items); the
        # model's forward consumes the list directly.
        return batch

    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=args.batch_size,
        collate_fn=collate_fn_identity,
        num_workers=getattr(args, "num_workers", 8),
    )

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        project_dir=output_path,
    )
    set_seed(seed, device_specific=True)

    if accelerator.is_main_process:
        accelerator.init_trackers(project_name=visual_log_project_name)
    model_logger = ModelLogger(accelerator, output_path)
    checkpoint_manager = CheckpointManager(accelerator, output_path, save_steps, save_epoches, max_checkpoints_to_keep)

    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    if resume_from_checkpoint:
        # Honor an explicit checkpoint path; a bare True means "latest".
        resume_target = "latest" if resume_from_checkpoint is True else resume_from_checkpoint
        load_path = checkpoint_manager.load_checkpoint(model, resume_target)
        global_step = 0
        if load_path:
            state_path = os.path.join(load_path, "trainer_state.json")
            if os.path.exists(state_path):
                with open(state_path, encoding="utf-8") as f:
                    global_step = json.load(f).get("global_step", 0)
            else:
                accelerator.print(f"[Resume] trainer_state.json not found in {load_path}, start from step 0.")
        else:
            accelerator.print("[Resume] No valid checkpoint path, start from scratch.")
        model_logger.global_step = global_step

    total_steps_per_epoch = len(dataloader)  # dataloader batches per epoch (per process)
    if total_steps_per_epoch == 0:
        raise ValueError("The training dataloader is empty; nothing to train on.")

    # global_step counts optimizer steps, not batches. With Accelerate's default
    # sync_with_dataloader=True, the last batch of each epoch also triggers a step,
    # so one epoch performs ceil(batches / accumulation) optimizer steps.
    grad_accum = max(1, accelerator.gradient_accumulation_steps)
    optimizer_steps_per_epoch = math.ceil(total_steps_per_epoch / grad_accum)
    starting_epoch, resume_update = divmod(model_logger.global_step, optimizer_steps_per_epoch)
    resume_step = resume_update * grad_accum  # batches to skip in starting_epoch

    accelerator.print("--- Starting Training ---")
    accelerator.print(f"Num Epochs: {num_epochs}")
    accelerator.print(f"Total steps per epoch: {total_steps_per_epoch}")
    accelerator.print(f"Optimizer steps per epoch: {optimizer_steps_per_epoch}")
    accelerator.print(f"Resuming from Epoch: {starting_epoch}, Step: {resume_step}")
    accelerator.print(f"Total number of data: {len(dataset)}")

    timers = get_timers()
    if accelerator.is_main_process:
        # The CLI default for visual_log_project_name is None, which os.path.join rejects.
        run_name = getattr(args, "visual_log_project_name", None) or visual_log_project_name or "default"
        save_loss_path = os.path.join(output_path, "training_loss_plots", run_name)
        loss_means = []
        os.makedirs(save_loss_path, exist_ok=True)

    def _save_video_frames(frames, path, fps=16):
        """Encode frames (PIL images, numpy arrays or tensors) into a video file at ``path``."""
        # len() instead of truthiness so numpy arrays / tensors are handled too.
        if frames is None or len(frames) == 0:
            return
        writer = imageio.get_writer(path, fps=fps)
        try:
            for frame in frames:
                if isinstance(frame, torch.Tensor):
                    frame = frame.detach().cpu().numpy()
                if hasattr(frame, "convert"):
                    frame = np.asarray(frame.convert("RGB"))
                frame = np.asarray(frame)
                if frame.dtype != np.uint8:
                    # NOTE(review): assumes non-uint8 frames are already in 0..255 — TODO confirm.
                    frame = np.clip(frame, 0, 255).astype(np.uint8)
                writer.append_data(frame)
        finally:
            writer.close()

    def _run_debug_infer(epoch_id):
        """Optionally render debug videos for selected dataset samples (main process only).

        Controlled by optional ``args.debug_infer*`` attributes; pipeline and scheduler
        state touched here is restored in the ``finally`` block.
        """
        if not getattr(args, "debug_infer", False):
            return
        interval = int(getattr(args, "debug_infer_interval", 1))
        if interval <= 0 or epoch_id % interval != 0:
            return
        if not accelerator.is_main_process:
            return

        indices = getattr(args, "debug_infer_indices", [0])
        if isinstance(indices, int):
            indices = [indices]
        if not isinstance(indices, (list, tuple)):
            indices = [0]

        debug_dir = os.path.join(output_path, "debug_infer", f"epoch_{epoch_id}")
        os.makedirs(debug_dir, exist_ok=True)
        debug_log = os.path.join(debug_dir, "debug.log")

        def log(msg):
            print(msg)
            with open(debug_log, "a", encoding="utf-8") as f:
                f.write(msg + "\n")

        model_unwrapped = accelerator.unwrap_model(model)
        model_unwrapped.eval()
        pipe = model_unwrapped.pipe
        # Snapshot state that inference may change; restored in `finally`.
        orig_pipe_device = getattr(pipe, "device", None)
        orig_pipe_dtype = getattr(pipe, "torch_dtype", None)
        orig_scheduler = pipe.scheduler
        orig_timesteps = getattr(orig_scheduler, "timesteps", None)
        orig_sigmas = getattr(orig_scheduler, "sigmas", None)
        orig_training = getattr(orig_scheduler, "training", False)
        orig_linear_weights = getattr(orig_scheduler, "linear_timesteps_weights", None)

        pipe.device = accelerator.device
        pipe.torch_dtype = torch.bfloat16

        cfg_scale = float(getattr(args, "debug_infer_cfg_scale", 5.0))
        cfg_scale_face = float(getattr(args, "debug_infer_cfg_scale_face", 5.0))
        num_steps = int(getattr(args, "debug_infer_steps", 8))
        infer_seed = int(getattr(args, "debug_infer_seed", args.seed))
        tiled = bool(getattr(args, "debug_infer_tiled", True))
        use_input_video = bool(getattr(args, "debug_infer_use_input_video", False))
        negative_prompt = getattr(args, "debug_infer_negative_prompt", "")

        try:
            for idx in indices:
                if idx < 0 or idx >= len(dataset):
                    log(f"[debug_infer] index {idx} out of range.")
                    continue
                sample = dataset[idx]
                input_video = sample.get("video", [])
                ref_images = sample.get("ref_images", [])
                if not input_video or not ref_images:
                    log(f"[debug_infer] sample {idx} missing video/ref, skip.")
                    continue

                prompt = sample.get("pre_shot_caption", ["xxx"])
                if isinstance(prompt, str):
                    prompt = [prompt]

                log("=" * 80)
                log(f"[debug_infer] epoch={epoch_id} index={idx}")
                log(f"video_path={sample.get('video_path')}")
                log(f"num_frames={len(input_video)} ref_num={sample.get('ref_num')} ID_num={sample.get('ID_num')}")
                log(f"prompt={prompt}")

                ref_dir = os.path.join(debug_dir, f"ref_{idx}")
                os.makedirs(ref_dir, exist_ok=True)
                for id_i, ref_group in enumerate(ref_images):
                    for img_i, img in enumerate(ref_group):
                        img.save(os.path.join(ref_dir, f"id{id_i}_img{img_i}.png"))
                if use_input_video:
                    _save_video_frames(input_video, os.path.join(debug_dir, f"input_{idx}.mp4"), fps=16)

                with torch.no_grad():
                    video, _ = pipe(
                        args=args,
                        prompt=[prompt],
                        negative_prompt=[negative_prompt],
                        input_video=[input_video] if use_input_video else None,
                        ref_images=[ref_images],
                        seed=infer_seed,
                        tiled=tiled,
                        height=input_video[0].size[1],
                        width=input_video[0].size[0],
                        num_frames=len(input_video),
                        cfg_scale=cfg_scale,
                        cfg_scale_face=cfg_scale_face,
                        num_inference_steps=num_steps,
                        num_ref_images=sample.get("ref_num"),
                    )
                _save_video_frames(video, os.path.join(debug_dir, f"output_{idx}.mp4"), fps=16)
        finally:
            # The pipeline may swap its scheduler during inference; put the training one back.
            pipe.scheduler = orig_scheduler
            if orig_pipe_device is not None:
                pipe.device = orig_pipe_device
            if orig_pipe_dtype is not None:
                pipe.torch_dtype = orig_pipe_dtype
            if orig_timesteps is not None:
                orig_scheduler.timesteps = orig_timesteps
            if orig_sigmas is not None:
                orig_scheduler.sigmas = orig_sigmas
            orig_scheduler.training = orig_training
            if orig_linear_weights is not None:
                orig_scheduler.linear_timesteps_weights = orig_linear_weights
            model_unwrapped.train()

    step_timer_running = False
    for epoch_id in range(starting_epoch, num_epochs):
        model.train()
        model_logger.set_epoch(epoch_id)

        # Skip already-consumed batches without loading them (only in the resumed epoch).
        skip = resume_step if epoch_id == starting_epoch else 0
        epoch_loader = accelerator.skip_first_batches(dataloader, skip) if skip > 0 else dataloader

        pbar = tqdm(
            enumerate(epoch_loader),
            initial=skip,
            total=total_steps_per_epoch,
            disable=not accelerator.is_main_process,
            desc=f"Epoch {epoch_id}",
        )

        for step, data in pbar:
            if not step_timer_running:
                # Time the whole optimizer step (all micro-batches); stopped in on_step_end.
                timers.step_time.start()
                step_timer_running = True
            with accelerator.accumulate(model):
                loss = model(data, args)
                accelerator.backward(loss)
                optimizer.step()
                scheduler.step()
                # Zero AFTER stepping: the prepared optimizer only zeroes on sync steps, so
                # zeroing at the top of the final micro-batch would discard the gradients
                # accumulated by the earlier micro-batches.
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                model_logger.on_step_end(loss, optimizer)
                step_timer_running = False
                checkpoint_manager.save_checkpoint(model, model_logger.global_step, epoch_id)

                # gather() is collective: every rank calls it, only the main rank plots.
                loss_mean = accelerator.gather(loss.detach()).mean().item()
                if accelerator.is_main_process and model_logger.global_step % 10 == 0:
                    loss_means.append(loss_mean)
                    plt.figure(figsize=(8, 5))
                    plt.plot(loss_means, marker="o", linestyle="-", label="Training Loss")
                    plt.xlabel("X (every 10 steps)")
                    plt.ylabel("Loss")
                    plt.title(f"Loss Curve up to step {model_logger.global_step}")
                    plt.grid(True)
                    plt.legend()
                    plt.savefig(os.path.join(save_loss_path, "loss_mean.png"))
                    plt.close()

        checkpoint_manager.save_epoch_checkpoint(model, model_logger.global_step, epoch_id)
        _run_debug_infer(epoch_id)

        resume_step = 0

    accelerator.wait_for_everyone()
    accelerator.print("--- Training Finished ---")

    accelerator.print("Saving final model...")
    final_save_path = os.path.join(checkpoint_manager.checkpoints_dir, "final_model")
    # save_state is collective (e.g. DeepSpeed writes per-rank optimizer shards), so every
    # rank must call it; Accelerate itself restricts shared files to the main process.
    accelerator.save_state(final_save_path)

    accelerator.end_training()
    model_logger.close()
|
|
| |
| |
|
|
|
|
def _str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "1"/"0", "yes"/"no".

    argparse's ``type=bool`` turns every non-empty string (including "False") into
    True, so boolean options use this converter instead.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def wan_parser():
    """Build the argument parser for Wan video training.

    Boolean options accept an explicit value (``--flag True`` / ``--flag false``) or
    may be given bare (``--flag``), which means True.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")

    def add_bool(name, help_text):
        # nargs="?" + const=True: "--flag" alone means True, "--flag False" means False.
        parser.add_argument(name, type=_str2bool, nargs="?", const=True, default=False, help=help_text)

    parser.add_argument("--dataset_base_path", type=str, default="", required=False, help="Base path of the dataset.")
    parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
    parser.add_argument("--max_pixels", type=int, default=1280*720, help="Maximum number of pixels per frame, used for dynamic resolution.")
    parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
    parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
    parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
    parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
    parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
    parser.add_argument("--output_path", type=str, default="./checkpoints", help="Output save path.")
    parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
    parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
    parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
    parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
    parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
    parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
    parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
    parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
    parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
    parser.add_argument("--save_steps", type=int, default=10, help="Save a checkpoint every N optimizer steps.")
    parser.add_argument("--train_yaml", type=str, default="../../../../conf/config.yaml", help="The train yaml file.")
    parser.add_argument("--max_checkpoints_to_keep", type=int, default=5, help="Maximum number of checkpoints to keep (<= 0 keeps all).")
    add_bool("--resume_from_checkpoint", "Resume training from the latest checkpoint in --output_path.")
    parser.add_argument("--seed", type=int, default=42, help="The random seed")
    parser.add_argument("--visual_log_project_name", type=str, default=None, help="Project name")
    parser.add_argument("--max_frames_per_batch", type=int, default=136, help="Max frames num for each batch")
    parser.add_argument("--prompt_index", type=int, default=-1, help="Prompt index for generation")
    parser.add_argument("--ref_num", type=int, default=3, help="The number of reference images")
    parser.add_argument("--local_model_path", type=str, default="", help="The default root path of the Wan weights")
    parser.add_argument("--batch_size", type=int, default=1, help="Per-process batch size.")
    parser.add_argument("--save_epoches", type=int, default=1, help="Save a checkpoint every N epochs (<= 0 disables).")
    add_bool("--split_rope", "Whether apply different rope into reference images ")
    add_bool("--split1", " ")
    add_bool("--split2", " ")
    add_bool("--split3", " ")
    add_bool("--shot_rope", "Whether apply shot rope for multi-shot video")

    return parser
|
|
|
|
if __name__ == '__main__':
    # Ad-hoc debugging entry point: load one metadata entry, try to read its history
    # video, print the entry when get_video_from_path returns None, then open pdb.
    # NOTE(review): `VideoDataset` is not imported anywhere in this module, so this
    # block raises NameError as written — TODO import it from its defining module.
    dataset = VideoDataset(use_history=True, base_path="/user/kg-aigc/rd_dev/qizipeng/luo_data_sorted_full.json")

    train_dataset = dataset.data
    for d in tqdm(train_dataset[7772:7773]):
        video = dataset.get_video_from_path(
            d["videoHis_path"], d["videoHis_time"], is_history=True
        )
        # `is None` rather than `== None`: if an array/tensor is returned, `==` compares
        # element-wise and makes the `if` ambiguous.
        if video is None:
            print(d)

    # Intentional breakpoint for interactive inspection of the loaded sample.
    import pdb; pdb.set_trace()
|
|