# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from moge.model.v1 import MoGeModel

from cosmos_predict1.diffusion.inference.cache_3d import Cache3D_Buffer
from cosmos_predict1.diffusion.inference.camera_utils import generate_camera_trajectory
from cosmos_predict1.diffusion.inference.gen3c_pipeline import Gen3cPipeline
from cosmos_predict1.diffusion.inference.inference_utils import (
    add_common_arguments,
    check_input_frames,
)
from cosmos_predict1.utils import log, misc
from cosmos_predict1.utils.io import read_prompts_from_file, save_video

# Inference only: disable gradient tracking globally.
torch.set_grad_enabled(False)


def create_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Video to world generation demo script")
    # Add common arguments
    add_common_arguments(parser)
    parser.add_argument(
        "--prompt_upsampler_dir",
        type=str,
        default="Pixtral-12B",
        help="Prompt upsampler weights directory relative to checkpoint_dir",
    )  # TODO: do we need this?
    parser.add_argument(
        "--input_image_path",
        type=str,
        help="Input image path for generating a single video",
    )
    parser.add_argument(
        "--trajectory",
        type=str,
        choices=[
            "left",
            "right",
            "up",
            "down",
            "zoom_in",
            "zoom_out",
            "clockwise",
            "counterclockwise",
            "none",
        ],
        default="left",
        help="Select a trajectory type from the available options (default: left)",
    )
    parser.add_argument(
        "--camera_rotation",
        type=str,
        choices=["center_facing", "no_rotation", "trajectory_aligned"],
        default="center_facing",
        help="Controls camera rotation during movement: center_facing (rotate to look at center), "
        "no_rotation (keep orientation), or trajectory_aligned (rotate in the direction of movement)",
    )
    parser.add_argument(
        "--movement_distance",
        type=float,
        default=0.3,
        help="Distance of the camera from the center of the scene",
    )
    parser.add_argument(
        "--noise_aug_strength",
        type=float,
        default=0.0,
        help="Strength of noise augmentation on warped frames",
    )
    parser.add_argument(
        "--save_buffer",
        action="store_true",
        help="If set, save the warped images (buffer) side by side with the output video.",
    )
    parser.add_argument(
        "--filter_points_threshold",
        type=float,
        default=0.05,
        help="Threshold for filtering 3D cache points with poor depth continuity in the warped images.",
    )
    parser.add_argument(
        "--foreground_masking",
        action="store_true",
        help="If set, use foreground masking for the warped images.",
    )
    return parser


def parse_arguments() -> argparse.Namespace:
    parser = create_parser()
    return parser.parse_args()


def validate_args(args):
    assert args.num_video_frames is not None, "num_video_frames must be provided"
    assert (args.num_video_frames - 1) % 120 == 0, "num_video_frames must be 121, 241, 361, ... (N*120+1)"
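    # This mirrors the autoregressive loop in demo(): the pipeline generates
    # 121-frame chunks that overlap by one frame, so each extra pass adds
    # 120 new frames (121 -> 1 pass, 241 -> 2 passes, 361 -> 3 passes, ...).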


def _predict_moge_depth(current_image_path: str | np.ndarray,
                        target_h: int, target_w: int,
                        device: torch.device, moge_model: MoGeModel):
    """Handles MoGe depth prediction for a single image.

    If the image is directly provided as a NumPy array, it should have shape [H, W, C],
    where the channels are RGB and the pixel values are in [0..255].
    """
    if isinstance(current_image_path, str):
        input_image_bgr = cv2.imread(current_image_path)
        if input_image_bgr is None:
            raise FileNotFoundError(f"Input image not found: {current_image_path}")
        input_image_rgb = cv2.cvtColor(input_image_bgr, cv2.COLOR_BGR2RGB)
    else:
        input_image_rgb = current_image_path
    del current_image_path
    depth_pred_h, depth_pred_w = 720, 1280
    input_image_for_depth_resized = cv2.resize(input_image_rgb, (depth_pred_w, depth_pred_h))
    input_image_for_depth_tensor_chw = torch.tensor(
        input_image_for_depth_resized / 255.0, dtype=torch.float32, device=device
    ).permute(2, 0, 1)
    moge_output_full = moge_model.infer(input_image_for_depth_tensor_chw)
    moge_depth_hw_full = moge_output_full["depth"]
    moge_intrinsics_33_full_normalized = moge_output_full["intrinsics"]
    moge_mask_hw_full = moge_output_full["mask"]
    # Push pixels MoGe marks as invalid (e.g. sky) to a large far-plane depth.
    moge_depth_hw_full = torch.where(
        moge_mask_hw_full == 0, torch.tensor(1000.0, device=moge_depth_hw_full.device), moge_depth_hw_full
    )
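    # MoGe returns intrinsics normalized by image size; scaling fx/cx by the
    # width and fy/cy by the height converts them to pixel units at the
    # depth-prediction resolution.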
    moge_intrinsics_33_full_pixel = moge_intrinsics_33_full_normalized.clone()
    moge_intrinsics_33_full_pixel[0, 0] *= depth_pred_w
    moge_intrinsics_33_full_pixel[1, 1] *= depth_pred_h
    moge_intrinsics_33_full_pixel[0, 2] *= depth_pred_w
    moge_intrinsics_33_full_pixel[1, 2] *= depth_pred_h
    # Calculate scaling factors for height and width
    height_scale_factor = target_h / depth_pred_h
    width_scale_factor = target_w / depth_pred_w
    # Resize depth map, mask, and image tensor
    # Resizing depth: (H, W) -> (1, 1, H, W) for interpolate, then squeeze
    moge_depth_hw = F.interpolate(
        moge_depth_hw_full.unsqueeze(0).unsqueeze(0),
        size=(target_h, target_w),
        mode='bilinear',
        align_corners=False,
    ).squeeze(0).squeeze(0)
    # Resizing mask: (H, W) -> (1, 1, H, W) for interpolate, then squeeze
    moge_mask_hw = F.interpolate(
        moge_mask_hw_full.unsqueeze(0).unsqueeze(0).to(torch.float32),
        size=(target_h, target_w),
        mode='nearest',  # Using nearest neighbor for binary mask
    ).squeeze(0).squeeze(0).to(torch.bool)
    # Resizing image tensor: (C, H, W) -> (1, C, H, W) for interpolate, then squeeze
    input_image_tensor_chw_target_res = F.interpolate(
        input_image_for_depth_tensor_chw.unsqueeze(0),
        size=(target_h, target_w),
        mode='bilinear',
        align_corners=False,
    ).squeeze(0)
    moge_image_b1chw_float = input_image_tensor_chw_target_res.unsqueeze(0).unsqueeze(1) * 2 - 1
    moge_intrinsics_33 = moge_intrinsics_33_full_pixel.clone()
    # Adjust intrinsics for the resized resolution
    moge_intrinsics_33[1, 1] *= height_scale_factor  # fy
    moge_intrinsics_33[1, 2] *= height_scale_factor  # cy
    moge_intrinsics_33[0, 0] *= width_scale_factor  # fx
    moge_intrinsics_33[0, 2] *= width_scale_factor  # cx
    moge_depth_b11hw = moge_depth_hw.unsqueeze(0).unsqueeze(0).unsqueeze(0)
    moge_depth_b11hw = torch.nan_to_num(moge_depth_b11hw, nan=1e4)
    moge_depth_b11hw = torch.clamp(moge_depth_b11hw, min=0, max=1e4)
    moge_mask_b11hw = moge_mask_hw.unsqueeze(0).unsqueeze(0).unsqueeze(0)
    # Prepare initial intrinsics [B, 1, 3, 3]
    moge_intrinsics_b133 = moge_intrinsics_33.unsqueeze(0).unsqueeze(0)
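    # The input view anchors the world frame, so its world-to-camera transform
    # is the identity; all trajectory poses are generated relative to it.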
    initial_w2c_44 = torch.eye(4, dtype=torch.float32, device=device)
    moge_initial_w2c_b144 = initial_w2c_44.unsqueeze(0).unsqueeze(0)
    return (
        moge_image_b1chw_float,
        moge_depth_b11hw,
        moge_mask_b11hw,
        moge_initial_w2c_b144,
        moge_intrinsics_b133,
    )


def _predict_moge_depth_from_tensor(
    image_tensor_chw_0_1: torch.Tensor,  # Shape (C, H_input, W_input), range [0,1]
    moge_model: MoGeModel,
):
    """Handles MoGe depth prediction from an image tensor."""
    moge_output_full = moge_model.infer(image_tensor_chw_0_1)
    moge_depth_hw_full = moge_output_full["depth"]  # (moge_inf_h, moge_inf_w)
    moge_mask_hw_full = moge_output_full["mask"]  # (moge_inf_h, moge_inf_w)
    moge_depth_11hw = moge_depth_hw_full.unsqueeze(0).unsqueeze(0)
    moge_depth_11hw = torch.nan_to_num(moge_depth_11hw, nan=1e4)
    moge_depth_11hw = torch.clamp(moge_depth_11hw, min=0, max=1e4)
    moge_mask_11hw = moge_mask_hw_full.unsqueeze(0).unsqueeze(0)
    moge_depth_11hw = torch.where(
        moge_mask_11hw == 0, torch.tensor(1000.0, device=moge_depth_11hw.device), moge_depth_11hw
    )
    return moge_depth_11hw, moge_mask_11hw


def demo(args):
    """Run video-to-world generation demo.

    This function handles the main video-to-world generation pipeline, including:
    - Setting up the random seed for reproducibility
    - Initializing the generation pipeline with the provided configuration
    - Processing single or multiple prompts/images from the input
    - Generating videos from prompts and images
    - Saving the generated videos to disk

    Args:
        args (argparse.Namespace): Configuration namespace containing:
            - Model configuration (checkpoint paths, model settings)
            - Generation parameters (guidance, steps, dimensions)
            - Input/output settings (prompts/images, save paths)
            - Performance options (model offloading settings)

    The function saves the generated MP4 video files to args.video_save_folder.

    If guardrails block the generation, a critical log message is displayed
    and the function continues to the next prompt if available.
    """
    misc.set_random_seed(args.seed)
    inference_type = "video2world"
    validate_args(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.num_gpus > 1:
        from megatron.core import parallel_state

        from cosmos_predict1.utils import distributed

        distributed.init()
        parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus)
        process_group = parallel_state.get_context_parallel_group()

    # Initialize video2world generation model pipeline
    pipeline = Gen3cPipeline(
        inference_type=inference_type,
        checkpoint_dir=args.checkpoint_dir,
        checkpoint_name="Gen3C-Cosmos-7B",
        prompt_upsampler_dir=args.prompt_upsampler_dir,
        enable_prompt_upsampler=not args.disable_prompt_upsampler,
        offload_network=args.offload_diffusion_transformer,
        offload_tokenizer=args.offload_tokenizer,
        offload_text_encoder_model=args.offload_text_encoder_model,
        offload_prompt_upsampler=args.offload_prompt_upsampler,
        offload_guardrail_models=args.offload_guardrail_models,
        disable_guardrail=args.disable_guardrail,
        disable_prompt_encoder=args.disable_prompt_encoder,
        guidance=args.guidance,
        num_steps=args.num_steps,
        height=args.height,
        width=args.width,
        fps=args.fps,
        num_video_frames=121,  # The pipeline always generates 121-frame chunks; longer videos are produced autoregressively below.
        seed=args.seed,
    )
    frame_buffer_max = pipeline.model.frame_buffer_max
    generator = torch.Generator(device=device).manual_seed(args.seed)
    sample_n_frames = pipeline.model.chunk_size
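    # chunk_size is the number of frames emitted per autoregressive pass;
    # frame_buffer_max caps how many buffered views the 3D cache may hold.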
    # Load the MoGe depth model and attach it to the pipeline.
    pipeline.model.model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device)
    # if args.num_gpus > 1:
    #     # pipeline.model.model is now the loaded MoGeModel instance.
    #     pipeline.model.model.enable_context_parallel(process_group)
    # Handle multiple prompts if a prompt file is provided
    if args.batch_input_path:
        log.info(f"Reading batch inputs from path: {args.batch_input_path}")
        prompts = read_prompts_from_file(args.batch_input_path)
    else:
        # Single prompt case
        prompts = [{"prompt": args.prompt, "visual_input": args.input_image_path}]

    os.makedirs(args.video_save_folder, exist_ok=True)
    for i, input_dict in enumerate(prompts):
        current_prompt = input_dict.get("prompt", None)
        if current_prompt is None and args.disable_prompt_upsampler:
            log.critical("Prompt is missing, skipping world generation.")
            continue
        current_image_path = input_dict.get("visual_input", None)
        if current_image_path is None:
            log.critical("Visual input is missing, skipping world generation.")
            continue

        # Check input frames
        if not check_input_frames(current_image_path, 1):
            log.critical(f"Input image {current_image_path} is not valid, skipping.")
            continue

        # Load the image, predict depth, and initialize the 3D cache
        (
            moge_image_b1chw_float,
            moge_depth_b11hw,
            moge_mask_b11hw,
            moge_initial_w2c_b144,
            moge_intrinsics_b133,
        ) = _predict_moge_depth(
            current_image_path, args.height, args.width, device, pipeline.model.model
        )
        cache = Cache3D_Buffer(
            frame_buffer_max=frame_buffer_max,
            generator=generator,
            noise_aug_strength=args.noise_aug_strength,
            input_image=moge_image_b1chw_float[:, 0].clone(),  # [B, C, H, W]
            input_depth=moge_depth_b11hw[:, 0],  # [B, 1, H, W]
            # input_mask=moge_mask_b11hw[:, 0],  # [B, 1, H, W]
            input_w2c=moge_initial_w2c_b144[:, 0],  # [B, 4, 4]
            input_intrinsics=moge_intrinsics_b133[:, 0],  # [B, 3, 3]
            filter_points_threshold=args.filter_points_threshold,
            foreground_masking=args.foreground_masking,
        )
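        # The cache lifts the input view into a 3D point cloud using the
        # predicted depth and intrinsics; render_cache() below re-renders
        # ("warps") that geometry from the novel camera poses.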
        initial_cam_w2c_for_traj = moge_initial_w2c_b144[0, 0]
        initial_cam_intrinsics_for_traj = moge_intrinsics_b133[0, 0]
        # Generate the camera trajectory using the utility function
        try:
            generated_w2cs, generated_intrinsics = generate_camera_trajectory(
                trajectory_type=args.trajectory,
                initial_w2c=initial_cam_w2c_for_traj,
                initial_intrinsics=initial_cam_intrinsics_for_traj,
                num_frames=args.num_video_frames,
                movement_distance=args.movement_distance,
                camera_rotation=args.camera_rotation,
                center_depth=1.0,
                device=device.type,
            )
        except (ValueError, NotImplementedError) as e:
            log.critical(f"Failed to generate trajectory: {e}")
            continue
        log.info(f"Generating frames 0 - {sample_n_frames}")
        rendered_warp_images, rendered_warp_masks = cache.render_cache(
            generated_w2cs[:, 0:sample_n_frames],
            generated_intrinsics[:, 0:sample_n_frames],
        )
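        # The warped renders and their validity masks condition the diffusion
        # model on the desired camera motion; the model fills in disoccluded
        # regions that the warp leaves empty.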
        all_rendered_warps = []
        if args.save_buffer:
            all_rendered_warps.append(rendered_warp_images.clone().cpu())

        # Generate the first video chunk
        generated_output = pipeline.generate(
            prompt=current_prompt,
            image_path=current_image_path,
            negative_prompt=args.negative_prompt,
            rendered_warp_images=rendered_warp_images,
            rendered_warp_masks=rendered_warp_masks,
        )
        if generated_output is None:
            log.critical("Guardrail blocked video2world generation.")
            continue
        video, prompt = generated_output

        num_ar_iterations = (generated_w2cs.shape[1] - 1) // (sample_n_frames - 1)
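        # Autoregressive continuation: each pass is seeded with the last frame
        # of the previous chunk, so consecutive chunks overlap by exactly one
        # frame and every pass contributes sample_n_frames - 1 new frames.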
        for num_iter in range(1, num_ar_iterations):
            start_frame_idx = num_iter * (sample_n_frames - 1)  # Overlap by 1 frame
            end_frame_idx = start_frame_idx + sample_n_frames
            log.info(f"Generating frames {start_frame_idx} - {end_frame_idx}")
            last_frame_hwc_0_255 = torch.tensor(video[-1], device=device)
            pred_image_for_depth_chw_0_1 = last_frame_hwc_0_255.permute(2, 0, 1) / 255.0  # (C,H,W), range [0,1]
            pred_depth, pred_mask = _predict_moge_depth_from_tensor(
                pred_image_for_depth_chw_0_1, pipeline.model.model
            )
            cache.update_cache(
                new_image=pred_image_for_depth_chw_0_1.unsqueeze(0) * 2 - 1,  # (B,C,H,W), range [-1,1]
                new_depth=pred_depth,  # (1,1,H,W)
                # new_mask=pred_mask,  # (1,1,H,W)
                new_w2c=generated_w2cs[:, start_frame_idx],
                new_intrinsics=generated_intrinsics[:, start_frame_idx],
            )
            current_segment_w2cs = generated_w2cs[:, start_frame_idx:end_frame_idx]
            current_segment_intrinsics = generated_intrinsics[:, start_frame_idx:end_frame_idx]
            rendered_warp_images, rendered_warp_masks = cache.render_cache(
                current_segment_w2cs,
                current_segment_intrinsics,
            )
            if args.save_buffer:
                # Drop the first (overlapping) frame of each later chunk.
                all_rendered_warps.append(rendered_warp_images[:, 1:].clone().cpu())
            pred_image_for_depth_bcthw_minus1_1 = pred_image_for_depth_chw_0_1.unsqueeze(0).unsqueeze(2) * 2 - 1  # (B,C,T,H,W), range [-1,1]
            generated_output = pipeline.generate(
                prompt=current_prompt,
                image_path=pred_image_for_depth_bcthw_minus1_1,
                negative_prompt=args.negative_prompt,
                rendered_warp_images=rendered_warp_images,
                rendered_warp_masks=rendered_warp_masks,
            )
            if generated_output is None:
                log.critical("Guardrail blocked video2world generation.")
                break
            video_new, prompt = generated_output
            video = np.concatenate([video, video_new[1:]], axis=0)

        # Final video processing
        final_video_to_save = video
        final_width = args.width
        if args.save_buffer and all_rendered_warps:
            squeezed_warps = [t.squeeze(0) for t in all_rendered_warps]  # Each is (T_chunk, n_i, C, H, W)
            if squeezed_warps:
                # Pad every chunk to the same number of buffer views n_max
                n_max = max(t.shape[1] for t in squeezed_warps)
                padded_t_list = []
                for sq_t in squeezed_warps:
                    # sq_t shape: (T_chunk, n_i, C, H, W)
                    current_n_i = sq_t.shape[1]
                    padding_needed_dim1 = n_max - current_n_i
                    pad_spec = (0, 0,  # W
                                0, 0,  # H
                                0, 0,  # C
                                0, padding_needed_dim1,  # n_i
                                0, 0)  # T_chunk
                    padded_t = F.pad(sq_t, pad_spec, mode='constant', value=-1.0)
                    padded_t_list.append(padded_t)
                full_rendered_warp_tensor = torch.cat(padded_t_list, dim=0)
                T_total, _, C_dim, H_dim, W_dim = full_rendered_warp_tensor.shape
                # Stack the n_max buffer views side by side along the width axis
                buffer_video_TCHnW = full_rendered_warp_tensor.permute(0, 2, 3, 1, 4)
                buffer_video_TCHWstacked = buffer_video_TCHnW.contiguous().view(T_total, C_dim, H_dim, n_max * W_dim)
                # Map from [-1, 1] back to [0, 255]
                buffer_video_TCHWstacked = (buffer_video_TCHWstacked * 0.5 + 0.5) * 255.0
                buffer_numpy_TCHWstacked = buffer_video_TCHWstacked.cpu().numpy().astype(np.uint8)
                buffer_numpy_THWC = np.transpose(buffer_numpy_TCHWstacked, (0, 2, 3, 1))
                final_video_to_save = np.concatenate([buffer_numpy_THWC, final_video_to_save], axis=2)
                final_width = args.width * (1 + n_max)
                log.info(f"Concatenating video with {n_max} warp buffers. Final video width will be {final_width}")
            else:
                log.info("No warp buffers to save.")

        video_save_path = os.path.join(
            args.video_save_folder,
            f"{i if args.batch_input_path else args.video_save_name}.mp4",
        )
        os.makedirs(os.path.dirname(video_save_path), exist_ok=True)

        # Save video
        save_video(
            video=final_video_to_save,
            fps=args.fps,
            H=args.height,
            W=final_width,
            video_save_quality=5,
            video_save_path=video_save_path,
        )
        log.info(f"Saved video to {video_save_path}")

    # Clean up properly
    if args.num_gpus > 1:
        parallel_state.destroy_model_parallel()
        import torch.distributed as dist

        dist.destroy_process_group()


if __name__ == "__main__":
    args = parse_arguments()
    if args.prompt is None:
        # Without a prompt, skip the guardrail and prompt upsampler stages.
        args.prompt = ""
        args.disable_guardrail = True
        args.disable_prompt_upsampler = True
    demo(args)
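# Example invocation (the script name and file paths are illustrative; all
# flags used here are defined above or come from add_common_arguments):
#   python gen3c_single_image.py \
#       --checkpoint_dir checkpoints \
#       --input_image_path assets/example_image.png \
#       --video_save_folder outputs \
#       --num_video_frames 121 \
#       --trajectory left --movement_distance 0.3 \
#       --save_buffer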