import torch
import torch.nn as nn
import lightning as pl
import wandb
import os
import copy
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
import json
import numpy as np
from PIL import Image
import imageio
import random
from torchvision.transforms import v2
from einops import rearrange
from pose_classifier import PoseClassifier
from scipy.spatial.transform import Rotation as R
import traceback
import argparse

def compute_relative_pose_matrix(pose1, pose2):
    """
    Compute the relative pose between two adjacent frames as a 3x4 camera matrix [R_rel | t_rel].

    Args:
        pose1: camera pose of frame i, array of shape (7,): [tx1, ty1, tz1, qx1, qy1, qz1, qw1]
        pose2: camera pose of frame i+1, array of shape (7,): [tx2, ty2, tz2, qx2, qy2, qz2, qw2]

    Returns:
        relative_matrix: 3x4 relative pose matrix; the first 3 columns are the rotation
        R_rel, the 4th column is the translation t_rel.
    """
    t1 = pose1[:3]
    q1 = pose1[3:]
    t2 = pose2[:3]
    q2 = pose2[3:]

    # Relative rotation: R_rel = R2 @ R1^T (scipy quaternions are [x, y, z, w]).
    rot1 = R.from_quat(q1)
    rot2 = R.from_quat(q2)
    rot_rel = rot2 * rot1.inv()
    R_rel = rot_rel.as_matrix()

    # Relative translation expressed in frame-1 coordinates: t_rel = R1^T @ (t2 - t1).
    R1_T = rot1.as_matrix().T
    t_rel = R1_T @ (t2 - t1)

    # Assemble the 3x4 matrix [R_rel | t_rel].
    relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)])

    return relative_matrix

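# Quick sanity check (a hedged sketch, not part of the training pipeline): for two
# poses that differ only by a unit translation along x with identity rotations
# ([qx, qy, qz, qw] = [0, 0, 0, 1]), the relative pose should be [I | (1, 0, 0)^T]:
#
#   p1 = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])
#   p2 = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])
#   np.allclose(compute_relative_pose_matrix(p1, p2),
#               np.hstack([np.eye(3), np.array([[1.0], [0.0], [0.0]])]))  # -> True
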
class SpatialVidFramePackDataset(torch.utils.data.Dataset):
    """SpatialVid dataset with support for the FramePack mechanism."""

    def __init__(self, base_path, steps_per_epoch,
                 min_condition_frames=10, max_condition_frames=40,
                 target_frames=10, height=900, width=1600):
        self.base_path = base_path
        self.scenes_path = base_path
        self.min_condition_frames = min_condition_frames
        self.max_condition_frames = max_condition_frames
        self.target_frames = target_frames
        self.height = height
        self.width = width
        self.steps_per_epoch = steps_per_epoch
        self.pose_classifier = PoseClassifier()

        # The VAE compresses 4 video frames into 1 latent frame.
        self.time_compression_ratio = 4

        # Collect every scene directory that contains pre-encoded latents.
        self.scene_dirs = []
        if os.path.exists(self.scenes_path):
            for item in os.listdir(self.scenes_path):
                scene_dir = os.path.join(self.scenes_path, item)
                if os.path.isdir(scene_dir):
                    encoded_path = os.path.join(scene_dir, "encoded_video.pth")
                    if os.path.exists(encoded_path):
                        self.scene_dirs.append(scene_dir)

        print(f"🔧 Found {len(self.scene_dirs)} SpatialVid scenes")
        assert len(self.scene_dirs) > 0, "No encoded scenes found!"

    def select_dynamic_segment_framepack(self, full_latents):
        """🔧 FramePack-style dynamic selection of condition and target frames - SpatialVid version."""
        total_lens = full_latents.shape[1]

        # Convert frame counts from video space to compressed latent space.
        min_condition_compressed = self.min_condition_frames // self.time_compression_ratio
        max_condition_compressed = self.max_condition_frames // self.time_compression_ratio
        target_frames_compressed = self.target_frames // self.time_compression_ratio
        max_condition_compressed = min(max_condition_compressed, total_lens - target_frames_compressed)

        # Randomly choose the number of condition frames:
        # 15% single frame, 75% uniform in [min, max], 10% same length as the target.
        ratio = random.random()
        if ratio < 0.15:
            condition_frames_compressed = 1
        elif 0.15 <= ratio < 0.9:
            condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed)
        else:
            condition_frames_compressed = target_frames_compressed

        min_required_frames = condition_frames_compressed + target_frames_compressed
        if total_lens < min_required_frames:
            print(f"Not enough compressed frames: {total_lens} < {min_required_frames}")
            return None

        # Pick a random start; clamp at 0 so randint never receives an empty range
        # when total_lens == min_required_frames.
        max_start = max(total_lens - min_required_frames - 1, 0)
        start_frame_compressed = random.randint(0, max_start)

        condition_end_compressed = start_frame_compressed + condition_frames_compressed
        target_end_compressed = condition_end_compressed + target_frames_compressed

        # Indices of the latents to denoise (the target segment).
        latent_indices = torch.arange(condition_end_compressed, target_end_compressed)

        # Full-resolution (1x) clean context: the first condition frame plus the one
        # immediately preceding the target segment.
        clean_latent_indices_start = torch.tensor([start_frame_compressed])
        clean_latent_1x_indices = torch.tensor([condition_end_compressed - 1])
        clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices])

        # 2x-downsampled context: up to 2 frames just before the target segment
        # (the start is clamped at 0 to avoid negative-index wraparound).
        if condition_frames_compressed >= 2:
            clean_latent_2x_start = max(start_frame_compressed, condition_end_compressed - 2)
            clean_latent_2x_indices = torch.arange(max(clean_latent_2x_start - 1, 0), condition_end_compressed - 1)
        else:
            clean_latent_2x_indices = torch.tensor([], dtype=torch.long)

        # 4x-downsampled context: up to 16 frames of older history (same clamping).
        if condition_frames_compressed >= 1:
            clean_4x_start = max(start_frame_compressed, condition_end_compressed - 16)
            clean_latent_4x_indices = torch.arange(max(clean_4x_start - 3, 0), condition_end_compressed - 3)
        else:
            clean_latent_4x_indices = torch.tensor([], dtype=torch.long)

        # Compressed-frame indices covering the whole selected segment
        # (condition + target), used later to look up camera poses.
        keyframe_original_idx = []
        for compressed_idx in range(start_frame_compressed, target_end_compressed):
            keyframe_original_idx.append(compressed_idx)

        return {
            'start_frame': start_frame_compressed,
            'condition_frames': condition_frames_compressed,
            'target_frames': target_frames_compressed,
            'condition_range': (start_frame_compressed, condition_end_compressed),
            'target_range': (condition_end_compressed, target_end_compressed),

            # FramePack multi-scale indices.
            'latent_indices': latent_indices,
            'clean_latent_indices': clean_latent_indices,
            'clean_latent_2x_indices': clean_latent_2x_indices,
            'clean_latent_4x_indices': clean_latent_4x_indices,

            'keyframe_original_idx': keyframe_original_idx,
            'original_condition_frames': condition_frames_compressed * self.time_compression_ratio,
            'original_target_frames': target_frames_compressed * self.time_compression_ratio,
        }

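    # Illustrative index layout (a hedged example, assuming total_lens >= 13): with
    # start_frame_compressed=0, condition_frames_compressed=5 and
    # target_frames_compressed=8, the method above would produce
    #   latent_indices          = [5, 6, ..., 12]  (frames to denoise)
    #   clean_latent_indices    = [0, 4]           (1x context: first + last condition frame)
    #   clean_latent_2x_indices = [2, 3]           (2x context)
    #   clean_latent_4x_indices = [0, 1]           (4x context, truncated by the clip start)
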
    def create_pose_embeddings(self, cam_data, segment_info):
        """🔧 Create SpatialVid-style pose embeddings - the camera interval is 1 frame rather than 4."""
        cam_data_seq = cam_data['extrinsic']

        # One relative pose per selected compressed frame, computed against the next frame.
        keyframe_original_idx = segment_info['keyframe_original_idx']

        relative_cams = []
        for idx in keyframe_original_idx:
            if idx + 1 < len(cam_data_seq):
                cam_prev = cam_data_seq[idx]
                cam_next = cam_data_seq[idx + 1]
                relative_cam = compute_relative_pose_matrix(cam_prev, cam_next)
                relative_cams.append(torch.as_tensor(relative_cam[:3, :]))
            else:
                # Past the last available pose: pad with an all-zero 3x4 matrix.
                pad_cam = torch.zeros(3, 4)
                relative_cams.append(pad_cam)

        if len(relative_cams) == 0:
            return None

        # (num_frames, 3, 4) -> (num_frames, 12)
        pose_embedding = torch.stack(relative_cams, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
        pose_embedding = pose_embedding.to(torch.bfloat16)

        return pose_embedding

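    # Shape note (hedged): for a segment of n condition + m target compressed frames,
    # create_pose_embeddings returns (n + m, 12) - each row a flattened 3x4 relative
    # pose. __getitem__ later appends a 1-dim condition mask, giving the (n + m, 13)
    # "camera" tensor that each block's cam_encoder (nn.Linear(13, dim)) consumes.
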
    def prepare_framepack_inputs(self, full_latents, segment_info):
        """🔧 Prepare FramePack-style multi-scale inputs - SpatialVid version."""
        # Normalize to a 5D (B, C, T, H, W) tensor.
        if len(full_latents.shape) == 4:
            full_latents = full_latents.unsqueeze(0)
        B, C, T, H, W = full_latents.shape

        # Target latents to denoise.
        latent_indices = segment_info['latent_indices']
        main_latents = full_latents[:, :, latent_indices, :, :]

        # Full-resolution (1x) clean context latents.
        clean_latent_indices = segment_info['clean_latent_indices']
        clean_latents = full_latents[:, :, clean_latent_indices, :, :]

        # 4x context: right-aligned into a fixed 16-frame buffer, with zeros and
        # -1 indices marking slots where history is missing.
        clean_latent_4x_indices = segment_info['clean_latent_4x_indices']
        clean_latents_4x = torch.zeros(B, C, 16, H, W, dtype=full_latents.dtype)
        clean_latent_4x_indices_final = torch.full((16,), -1, dtype=torch.long)

        if len(clean_latent_4x_indices) > 0:
            actual_4x_frames = len(clean_latent_4x_indices)
            start_pos = max(0, 16 - actual_4x_frames)
            end_pos = 16
            actual_start = max(0, actual_4x_frames - 16)

            clean_latents_4x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_4x_indices[actual_start:], :, :]
            clean_latent_4x_indices_final[start_pos:end_pos] = clean_latent_4x_indices[actual_start:]

        # 2x context: same right-aligned scheme with a fixed 2-frame buffer.
        clean_latent_2x_indices = segment_info['clean_latent_2x_indices']
        clean_latents_2x = torch.zeros(B, C, 2, H, W, dtype=full_latents.dtype)
        clean_latent_2x_indices_final = torch.full((2,), -1, dtype=torch.long)

        if len(clean_latent_2x_indices) > 0:
            actual_2x_frames = len(clean_latent_2x_indices)
            start_pos = max(0, 2 - actual_2x_frames)
            end_pos = 2
            actual_start = max(0, actual_2x_frames - 2)

            clean_latents_2x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_2x_indices[actual_start:], :, :]
            clean_latent_2x_indices_final[start_pos:end_pos] = clean_latent_2x_indices[actual_start:]

        # Drop the batch dim again; the DataLoader (batch_size=1) adds it back.
        if B == 1:
            main_latents = main_latents.squeeze(0)
            clean_latents = clean_latents.squeeze(0)
            clean_latents_2x = clean_latents_2x.squeeze(0)
            clean_latents_4x = clean_latents_4x.squeeze(0)

        return {
            'latents': main_latents,
            'clean_latents': clean_latents,
            'clean_latents_2x': clean_latents_2x,
            'clean_latents_4x': clean_latents_4x,
            'latent_indices': segment_info['latent_indices'],
            'clean_latent_indices': segment_info['clean_latent_indices'],
            'clean_latent_2x_indices': clean_latent_2x_indices_final,
            'clean_latent_4x_indices': clean_latent_4x_indices_final,
        }

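    # Shape summary (hedged; m = target_frames_compressed, C = 16 latent channels as
    # assumed by clean_x_embedder below):
    #   latents          (C, m, H, W)   - target segment to denoise
    #   clean_latents    (C, 2, H, W)   - 1x context (first + last condition frame)
    #   clean_latents_2x (C, 2, H, W)   - 2x context, zero-padded on the left
    #   clean_latents_4x (C, 16, H, W)  - 4x context, zero-padded on the left
    # The corresponding index tensors use -1 to mark padded slots.
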
    def __getitem__(self, index):
        while True:
            try:
                # Sample a random scene.
                scene_dir = random.choice(self.scene_dirs)

                # Load the pre-encoded latents and camera data.
                encoded_data = torch.load(
                    os.path.join(scene_dir, "encoded_video.pth"),
                    weights_only=False,
                    map_location="cpu"
                )

                full_latents = encoded_data['latents']
                cam_data = encoded_data['cam_emb']
                actual_latent_frames = full_latents.shape[1]

                # Dynamically select condition/target segments; retry on failure.
                segment_info = self.select_dynamic_segment_framepack(full_latents)
                if segment_info is None:
                    continue

                all_camera_embeddings = self.create_pose_embeddings(cam_data, segment_info)
                if all_camera_embeddings is None:
                    continue

                framepack_inputs = self.prepare_framepack_inputs(full_latents, segment_info)

                n = segment_info["condition_frames"]
                m = segment_info['target_frames']

                # Condition mask: 1 for condition frames, 0 for target frames.
                mask = torch.zeros(n + m, dtype=torch.float32)
                mask[:n] = 1.0
                mask = mask.view(-1, 1)

                # Concatenate pose (12 dims) and mask (1 dim) -> (n + m, 13).
                camera_with_mask = torch.cat([all_camera_embeddings, mask], dim=1)

                result = {
                    # FramePack multi-scale latents and indices.
                    "latents": framepack_inputs['latents'],
                    "clean_latents": framepack_inputs['clean_latents'],
                    "clean_latents_2x": framepack_inputs['clean_latents_2x'],
                    "clean_latents_4x": framepack_inputs['clean_latents_4x'],
                    "latent_indices": framepack_inputs['latent_indices'],
                    "clean_latent_indices": framepack_inputs['clean_latent_indices'],
                    "clean_latent_2x_indices": framepack_inputs['clean_latent_2x_indices'],
                    "clean_latent_4x_indices": framepack_inputs['clean_latent_4x_indices'],

                    # Camera control signal.
                    "camera": camera_with_mask,

                    "prompt_emb": encoded_data["prompt_emb"],
                    "image_emb": encoded_data.get("image_emb", {}),

                    "condition_frames": n,
                    "target_frames": m,
                    "scene_name": os.path.basename(scene_dir),
                    "dataset_name": "spatialvid",

                    "original_condition_frames": segment_info['original_condition_frames'],
                    "original_target_frames": segment_info['original_target_frames'],
                }

                return result

            except Exception as e:
                print(f"Error loading sample: {e}")
                traceback.print_exc()
                continue

    def __len__(self):
        return self.steps_per_epoch

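# Minimal smoke test for the dataset (a hedged sketch; the path is a placeholder):
#
#   ds = SpatialVidFramePackDataset("/path/to/spatialvid", steps_per_epoch=4)
#   sample = ds[0]
#   # sample["latents"]: (C, target_frames_compressed, H, W)
#   # sample["camera"]:  (n + m, 13)
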
def replace_dit_model_in_manager():
    """Swap in the MoE DiT model class before any models are loaded."""
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for i, config in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config

        if 'wan_video_dit' in model_names:
            # Rebuild the (name, class) lists, substituting WanModelMoe for the stock DiT.
            new_model_names = []
            new_model_classes = []

            for name, cls in zip(model_names, model_classes):
                if name == 'wan_video_dit':
                    new_model_names.append(name)
                    new_model_classes.append(WanModelMoe)
                    print(f"✅ Replaced model class: {name} -> WanModelMoe")
                else:
                    new_model_names.append(name)
                    new_model_classes.append(cls)

            model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource)


class SpatialVidFramePackLightningModel(pl.LightningModule):
    def __init__(
        self,
        dit_path,
        learning_rate=1e-5,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        resume_ckpt_path=None
    ):
        super().__init__()
        replace_dit_model_in_manager()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        if os.path.isfile(dit_path):
            model_manager.load_models([dit_path])
        else:
            # A comma-separated list of shard paths.
            dit_path = dit_path.split(",")
            model_manager.load_models(dit_path)
        model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"])

        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.pipe.scheduler.set_timesteps(1000, training=True)

        # Attach FramePack and MoE components before loading any resume checkpoint.
        self.add_framepack_components()
        self.add_moe_components()

        # Per-block camera encoder (12 pose dims + 1 mask dim = 13), zero-initialized,
        # with a projector initialized to the identity so training starts from the
        # base model's behavior.
        dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
        for block in self.pipe.dit.blocks:
            block.cam_encoder = nn.Linear(13, dim)
            block.projector = nn.Linear(dim, dim)
            block.cam_encoder.weight.data.zero_()
            block.cam_encoder.bias.data.zero_()
            block.projector.weight = nn.Parameter(torch.eye(dim))
            block.projector.bias = nn.Parameter(torch.zeros(dim))

        if resume_ckpt_path is not None:
            state_dict = torch.load(resume_ckpt_path, map_location="cpu")
            self.pipe.dit.load_state_dict(state_dict, strict=False)
            print('load checkpoint:', resume_ckpt_path)

        self.freeze_parameters()

        # Only the MoE and modality-processor parameters are trained.
        for name, module in self.pipe.denoising_model().named_modules():
            if any(keyword in name for keyword in ["moe", "sekai_processor"]):
                for param in module.parameters():
                    param.requires_grad = True

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

        self.vis_dir = "spatialvid_framepack/visualizations"
        os.makedirs(self.vis_dir, exist_ok=True)

    def add_framepack_components(self):
        """🔧 Add the FramePack components."""
        if not hasattr(self.pipe.dit, 'clean_x_embedder'):
            inner_dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]

            class CleanXEmbedder(nn.Module):
                def __init__(self, inner_dim):
                    super().__init__()
                    # Patchify kernels: coarser context uses larger patches, so it
                    # contributes fewer tokens to the transformer.
                    self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
                    self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
                    self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

                def forward(self, x, scale="1x"):
                    if scale == "1x":
                        return self.proj(x)
                    elif scale == "2x":
                        return self.proj_2x(x)
                    elif scale == "4x":
                        return self.proj_4x(x)
                    else:
                        raise ValueError(f"Unsupported scale: {scale}")

            self.pipe.dit.clean_x_embedder = CleanXEmbedder(inner_dim)
            print("✅ Added FramePack clean_x_embedder component")
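    # Token budget (hedged arithmetic): for a (16, T, H, W) latent, the 1x kernel
    # yields T * (H/2) * (W/2) tokens, 2x yields (T/2) * (H/4) * (W/4), and 4x yields
    # (T/4) * (H/8) * (W/8) - each coarser scale costs ~8x fewer tokens per frame,
    # which is what lets FramePack attend to long histories at near-constant cost.
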
    def add_moe_components(self):
        """🔧 Add the MoE components - mirrors add_framepack_components."""
        # Propagate an optional MoE config onto the model if one was provided.
        moe_config = getattr(self, "moe_config", None)
        if moe_config is not None and not hasattr(self.pipe.dit, 'moe_config'):
            self.pipe.dit.moe_config = moe_config
            print("✅ Attached MoE config to the model")

        from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE

        dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
        unified_dim = 25

        for block in self.pipe.dit.blocks:
            # Project the 13-dim camera signal (12 pose + 1 mask) into the unified
            # modality space shared by all experts.
            block.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)

            # Single-expert MoE mapping the unified space to the transformer width.
            block.moe = MultiModalMoE(
                unified_dim=unified_dim,
                output_dim=dim,
                num_experts=1,
                top_k=1
            )

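    # Data flow (hedged, based on how training_step calls the model): each block's
    # sekai_processor receives modality_inputs["sekai"] - the (n + m, 13) camera
    # tensor - projects it to unified_dim, and block.moe routes it (trivially, with
    # num_experts=1) into a dim-sized feature that conditions the block.
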
    def freeze_parameters(self):
        # Freeze the whole pipeline, then put the denoising model in train mode.
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.denoising_model().train()

    def training_step(self, batch, batch_idx):
        """🔧 FramePack-style training step - SpatialVid version."""
        condition_frames = batch["condition_frames"][0].item()
        target_frames = batch["target_frames"][0].item()

        original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0]
        original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0]

        dataset_name = batch.get("dataset_name", ["unknown"])[0]
        scene_name = batch.get("scene_name", ["unknown"])[0]

        # Target latents.
        latents = batch["latents"].to(self.device)
        if len(latents.shape) == 4:
            latents = latents.unsqueeze(0)

        # Multi-scale clean context latents (None when empty).
        clean_latents = batch["clean_latents"].to(self.device) if batch["clean_latents"].numel() > 0 else None
        if clean_latents is not None and len(clean_latents.shape) == 4:
            clean_latents = clean_latents.unsqueeze(0)

        clean_latents_2x = batch["clean_latents_2x"].to(self.device) if batch["clean_latents_2x"].numel() > 0 else None
        if clean_latents_2x is not None and len(clean_latents_2x.shape) == 4:
            clean_latents_2x = clean_latents_2x.unsqueeze(0)

        clean_latents_4x = batch["clean_latents_4x"].to(self.device) if batch["clean_latents_4x"].numel() > 0 else None
        if clean_latents_4x is not None and len(clean_latents_4x.shape) == 4:
            clean_latents_4x = clean_latents_4x.unsqueeze(0)

        # Index tensors.
        latent_indices = batch["latent_indices"].to(self.device)
        clean_latent_indices = batch["clean_latent_indices"].to(self.device) if batch["clean_latent_indices"].numel() > 0 else None
        clean_latent_2x_indices = batch["clean_latent_2x_indices"].to(self.device) if batch["clean_latent_2x_indices"].numel() > 0 else None
        clean_latent_4x_indices = batch["clean_latent_4x_indices"].to(self.device) if batch["clean_latent_4x_indices"].numel() > 0 else None

        # Camera dropout for classifier-free guidance (CFG) training.
        cam_emb = batch["camera"].to(self.device)
        camera_dropout_prob = 0.1
        if random.random() < camera_dropout_prob:
            cam_emb = torch.zeros_like(cam_emb)
            print("Applied camera dropout for CFG training")

        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
        image_emb = batch["image_emb"]

        if "clip_feature" in image_emb:
            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
        if "y" in image_emb:
            image_emb["y"] = image_emb["y"][0].to(self.device)

        # Sample a training timestep and noise the target latents.
        self.pipe.device = self.device
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)

        # With 80% probability, also noise the 1x condition latents (timestep drawn
        # from the first 3/4 of the schedule) so the model learns to handle
        # imperfect context.
        noisy_condition_latents = None
        if clean_latents is not None:
            noisy_condition_latents = copy.deepcopy(clean_latents)
            is_add_noise = random.random()
            if is_add_noise > 0.2:
                noise_cond = torch.randn_like(clean_latents)
                timestep_id_cond = torch.randint(0, self.pipe.scheduler.num_train_timesteps // 4 * 3, (1,))
                timestep_cond = self.pipe.scheduler.timesteps[timestep_id_cond].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
                noisy_condition_latents = self.pipe.scheduler.add_noise(clean_latents, noise_cond, timestep_cond)

        extra_input = self.pipe.prepare_extra_input(latents)
        origin_latents = copy.deepcopy(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)

        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Forward pass through the denoising model with FramePack inputs.
        # (moe_loss is returned but currently not added to the training loss.)
        noise_pred, moe_loss = self.pipe.denoising_model()(
            noisy_latents,
            timestep=timestep,
            cam_emb=cam_emb,
            modality_inputs={"sekai": cam_emb},
            latent_indices=latent_indices,
            clean_latents=noisy_condition_latents if noisy_condition_latents is not None else clean_latents,
            clean_latent_indices=clean_latent_indices,
            clean_latents_2x=clean_latents_2x,
            clean_latent_2x_indices=clean_latent_2x_indices,
            clean_latents_4x=clean_latents_4x,
            clean_latent_4x_indices=clean_latent_4x_indices,
            **prompt_emb,
            **extra_input,
            **image_emb,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
        )

        # MSE against the scheduler's training target, weighted by its timestep weight.
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)
        print(f'--------loss ({dataset_name})------------:', loss)

        return loss

    def configure_optimizers(self):
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        checkpoint_dir = "/share_zhuyixuan05/zhuyixuan05/ICLR2026/spatialvid/spatialvid_moe_test"
        os.makedirs(checkpoint_dir, exist_ok=True)

        current_step = self.global_step
        # Empty the Lightning checkpoint and save a plain state_dict instead.
        checkpoint.clear()

        state_dict = self.pipe.denoising_model().state_dict()
        torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}.ckpt"))
        print(f"Saved SpatialVid FramePack model checkpoint: step{current_step}.ckpt")


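# Reloading a saved checkpoint (a hedged sketch; mirrors the resume path in
# __init__, which loads the raw state_dict with strict=False):
#
#   state_dict = torch.load("step1000.ckpt", map_location="cpu")
#   model.pipe.dit.load_state_dict(state_dict, strict=False)
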
def train_spatialvid_framepack(args):
    """Train the SpatialVid model with FramePack support."""
    dataset = SpatialVidFramePackDataset(
        args.dataset_path,
        steps_per_epoch=args.steps_per_epoch,
        min_condition_frames=args.min_condition_frames,
        max_condition_frames=args.max_condition_frames,
        target_frames=args.target_frames,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )

    model = SpatialVidFramePackLightningModel(
        dit_path=args.dit_path,
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        resume_ckpt_path=args.resume_ckpt_path,
    )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision="bf16",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
        logger=False
    )
    trainer.fit(model, dataloader)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train SpatialVid FramePack Dynamic ReCamMaster")
    parser.add_argument("--dataset_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/spatialvid")
    parser.add_argument("--dit_path", type=str, default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors")
    parser.add_argument("--output_path", type=str, default="./")
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--steps_per_epoch", type=int, default=400)
    parser.add_argument("--max_epochs", type=int, default=30)
    parser.add_argument("--min_condition_frames", type=int, default=10, help="Minimum number of condition frames")
    parser.add_argument("--max_condition_frames", type=int, default=40, help="Maximum number of condition frames")
    parser.add_argument("--target_frames", type=int, default=32, help="Number of target frames")
    parser.add_argument("--dataloader_num_workers", type=int, default=4)
    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--training_strategy", type=str, default="deepspeed_stage_1")
    parser.add_argument("--use_gradient_checkpointing", action="store_true")
    parser.add_argument("--use_gradient_checkpointing_offload", action="store_true")
    parser.add_argument("--resume_ckpt_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack/step1000_framepack.ckpt")

    args = parser.parse_args()

    print("🔧 Training the SpatialVid FramePack model:")
    print(f"📁 Dataset path: {args.dataset_path}")
    print(f"🎯 Condition frame range: {args.min_condition_frames}-{args.max_condition_frames}")
    print(f"🎯 Target frames: {args.target_frames}")
    print("🔧 Key details:")
    print("   - Uses the WanModelMoe architecture")
    print("   - Adds FramePack multi-scale input support")
    print("   - SpatialVid-specific: camera interval of 1 frame")
    print("   - CFG training support (10% camera dropout)")

    train_spatialvid_framepack(args)