|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
from collections import defaultdict |
|
from typing import List, Optional, Union |
|
|
|
import cv2 |
|
import einops |
|
import numpy as np |
|
import torch |
|
from tqdm import tqdm |
|
|
|
from cosmos_transfer1.auxiliary.upsampler.model.upsampler import PixtralPromptUpsampler |
|
from cosmos_transfer1.checkpoints import ( |
|
BASE_7B_CHECKPOINT_AV_SAMPLE_PATH, |
|
BASE_7B_CHECKPOINT_PATH, |
|
COSMOS_TOKENIZER_CHECKPOINT, |
|
DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH, |
|
HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH, |
|
VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, |
|
BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, |
|
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH, |
|
SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH, |
|
SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
) |
|
from cosmos_transfer1.diffusion.inference.inference_utils import ( |
|
detect_aspect_ratio, |
|
generate_control_input, |
|
generate_world_from_control, |
|
get_batched_ctrl_batch, |
|
get_ctrl_batch, |
|
get_ctrl_batch_mv, |
|
get_upscale_size, |
|
get_video_batch, |
|
get_video_batch_for_multiview_model, |
|
load_model_by_config, |
|
load_network_model, |
|
load_tokenizer_model, |
|
merge_patches_into_video, |
|
non_strict_load_model, |
|
read_and_resize_input, |
|
read_video_or_image_into_frames_BCTHW, |
|
resize_control_weight_map, |
|
resize_video, |
|
split_video_into_patches, |
|
valid_hint_keys, |
|
) |
|
from cosmos_transfer1.diffusion.model.model_ctrl import ( |
|
VideoDiffusionModelWithCtrl, |
|
VideoDiffusionT2VModelWithCtrl, |
|
VideoDistillModelWithCtrl, |
|
) |
|
from cosmos_transfer1.diffusion.model.model_multi_camera_ctrl import MultiVideoDiffusionModelWithCtrl |
|
from cosmos_transfer1.diffusion.module.parallel import broadcast |
|
from cosmos_transfer1.utils import log |
|
from cosmos_transfer1.utils.base_world_generation_pipeline import BaseWorldGenerationPipeline |
|
from cosmos_transfer1.utils.regional_prompting_utils import prepare_regional_prompts |
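# Maps each supported checkpoint to the config job name consumed by load_model_by_config.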
|
|
|
MODEL_NAME_DICT = { |
|
BASE_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3", |
|
EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3", |
|
VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_vis_block3", |
|
DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_depth_block3", |
|
KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_keypoint_block3", |
|
SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3", |
|
UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_upscale_block3", |
|
BASE_7B_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_hdmap_block3", |
|
HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_hdmap_block3", |
|
LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_lidar_block3", |
|
BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_block3", |
|
BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_sv2mv_v2w_57frames_control_input_hdmap_block3", |
|
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_block3", |
|
SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_lidar_block3", |
|
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_waymo_block3", |
|
SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_v2w_57frames_control_input_hdmap_waymo_block3", |
|
EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH: "dev_v2w_ctrl_7bv1pt3_VisControlCanny_video_only_dmd2_fsdp", |
|
} |
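# Maps each supported checkpoint to the model class that implements its architecture
# (single-view, T2V, multi-view, or distilled).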
|
MODEL_CLASS_DICT = { |
|
BASE_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, |
|
EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, |
|
VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, |
|
DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, |
|
SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, |
|
KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, |
|
UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, |
|
BASE_7B_CHECKPOINT_AV_SAMPLE_PATH: VideoDiffusionT2VModelWithCtrl, |
|
HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionT2VModelWithCtrl, |
|
LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionT2VModelWithCtrl, |
|
BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: MultiVideoDiffusionModelWithCtrl, |
|
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, |
|
SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, |
|
BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: MultiVideoDiffusionModelWithCtrl, |
|
SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, |
|
SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, |
|
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, |
|
SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, |
|
EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH: VideoDistillModelWithCtrl, |
|
} |
|
|
|
|
|
|
|
|
class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline): |
|
def __init__( |
|
self, |
|
checkpoint_dir: str, |
|
checkpoint_name: str, |
|
has_text_input: bool = True, |
|
offload_network: bool = False, |
|
offload_tokenizer: bool = False, |
|
offload_text_encoder_model: bool = False, |
|
offload_guardrail_models: bool = False, |
|
guidance: float = 7.0, |
|
num_steps: int = 35, |
|
height: int = 704, |
|
width: int = 1280, |
|
fps: int = 24, |
|
num_video_frames: int = 121, |
|
seed: int = 0, |
|
num_input_frames: int = 1, |
|
control_inputs: dict = None, |
|
sigma_max: float = 70.0, |
|
blur_strength: str = "medium", |
|
canny_threshold: str = "medium", |
|
upsample_prompt: bool = False, |
|
offload_prompt_upsampler: bool = False, |
|
process_group: torch.distributed.ProcessGroup | None = None, |
|
regional_prompts: List[str] = None, |
|
region_definitions: Union[List[List[float]], torch.Tensor] = None, |
|
waymo_example: bool = False, |
|
chunking: bool = False, |
|
): |
|
"""Initialize diffusion world generation pipeline. |
|
|
|
Args: |
|
checkpoint_dir: Base directory containing model checkpoints |
|
checkpoint_name: Name of the diffusion transformer checkpoint to use |
|
has_text_input: Whether the pipeline takes text input for world generation |
|
offload_network: Whether to offload diffusion transformer after inference |
|
offload_tokenizer: Whether to offload tokenizer after inference |
|
offload_text_encoder_model: Whether to offload T5 model after inference |
|
offload_guardrail_models: Whether to offload guardrail models |
|
guidance: Classifier-free guidance scale |
|
num_steps: Number of diffusion sampling steps |
|
height: Height of output video |
|
width: Width of output video |
|
fps: Frames per second of output video |
|
num_video_frames: Number of frames to generate |
|
seed: Random seed for sampling |
|
num_input_frames: Number of latent conditions |
|
control_inputs: Dictionary of control inputs for guided generation |
|
sigma_max: Sigma max for partial denoising |
|
blur_strength: Strength of blur applied to input |
|
canny_threshold: Threshold for edge detection |
|
upsample_prompt: Whether to upsample prompts using prompt upsampler model |
|
offload_prompt_upsampler: Whether to offload prompt upsampler after use |
|
            process_group: Process group for context-parallel inference
            regional_prompts: Optional list of regional prompts for regional prompting
            region_definitions: Optional list of region bounding boxes (or a tensor of region masks) matching regional_prompts
            waymo_example: Whether to use the Waymo example post-training checkpoint
            chunking: Whether to use chunked processing in the generation pipeline
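
        Example:
            A minimal usage sketch. The checkpoint directory, input paths, and the
            exact ``control_inputs`` spec below are illustrative placeholders; the
            full spec format is defined by the transfer inference utilities.

            >>> pipe = DiffusionControl2WorldGenerationPipeline(
            ...     checkpoint_dir="checkpoints",
            ...     checkpoint_name=EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
            ...     control_inputs={"edge": {"ckpt_path": EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH}},
            ... )
            >>> result = pipe.generate(
            ...     prompt="A robot arm assembling a toy car",
            ...     video_path="assets/example_input.mp4",
            ...     control_inputs={"edge": {"ckpt_path": EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH}},
            ... )
            >>> if result is not None:
            ...     videos, final_prompts = result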
|
""" |
|
self.num_input_frames = num_input_frames |
|
self.control_inputs = control_inputs |
|
self.sigma_max = sigma_max |
|
self.blur_strength = blur_strength |
|
self.canny_threshold = canny_threshold |
|
self.upsample_prompt = upsample_prompt |
|
self.offload_prompt_upsampler = offload_prompt_upsampler |
|
self.prompt_upsampler = None |
|
self.upsampler_hint_key = None |
|
self.hint_details = None |
|
self.process_group = process_group |
|
self.model_name = MODEL_NAME_DICT[checkpoint_name] |
|
self.model_class = MODEL_CLASS_DICT[checkpoint_name] |
|
self.guidance = guidance |
|
self.num_steps = num_steps |
|
self.height = height |
|
self.width = width |
|
self.fps = fps |
|
self.num_video_frames = num_video_frames |
|
self.seed = seed |
|
self.regional_prompts = regional_prompts |
|
self.region_definitions = region_definitions |
|
self.chunking = chunking |
|
|
|
super().__init__( |
|
checkpoint_dir=checkpoint_dir, |
|
checkpoint_name=checkpoint_name, |
|
has_text_input=has_text_input, |
|
offload_network=offload_network, |
|
offload_tokenizer=offload_tokenizer, |
|
offload_text_encoder_model=offload_text_encoder_model, |
|
offload_guardrail_models=offload_guardrail_models, |
|
) |
|
|
|
|
|
if self.upsample_prompt: |
|
            if int(os.environ.get("RANK", "0")) == 0:
|
self._push_torchrun_environ_variables() |
|
self._init_prompt_upsampler() |
|
self._pop_torchrun_environ_variables() |
|
|
|
def _push_torchrun_environ_variables(self): |
|
dist_keys = [ |
|
"RANK", |
|
"LOCAL_RANK", |
|
"WORLD_SIZE", |
|
"LOCAL_WORLD_SIZE", |
|
"GROUP_RANK", |
|
"ROLE_RANK", |
|
"ROLE_NAME", |
|
"OMP_NUM_THREADS", |
|
"MASTER_ADDR", |
|
"MASTER_PORT", |
|
"TORCHELASTIC_USE_AGENT_STORE", |
|
"TORCHELASTIC_MAX_RESTARTS", |
|
"TORCHELASTIC_RUN_ID", |
|
"TORCH_NCCL_ASYNC_ERROR_HANDLING", |
|
"TORCHELASTIC_ERROR_FILE", |
|
] |
|
|
|
self.torchrun_environ_variables = {} |
|
for dist_key in dist_keys: |
|
if dist_key in os.environ: |
|
self.torchrun_environ_variables[dist_key] = os.environ[dist_key] |
|
del os.environ[dist_key] |
|
|
|
def _pop_torchrun_environ_variables(self): |
|
for dist_key in self.torchrun_environ_variables.keys(): |
|
os.environ[dist_key] = self.torchrun_environ_variables[dist_key] |
|
|
|
def _init_prompt_upsampler(self): |
|
""" |
|
Initializes the prompt upsampler based on the provided control inputs. |
|
|
|
Returns: |
|
None: Sets instance variables for prompt upsampler, hint key, and hint details |
|
""" |
|
vis_hint_keys = ["vis", "edge"] |
|
other_hint_keys = ["seg", "depth"] |
|
self.hint_details = None |
|
|
|
log.info("Initializing prompt upsampler...") |
|
|
|
if any(key in vis_hint_keys for key in self.control_inputs): |
|
self.upsampler_hint_key = "vis" |
|
self.hint_details = "vis" if "vis" in self.control_inputs else "edge" |
|
elif any(key in other_hint_keys for key in self.control_inputs): |
|
selected_hint_keys = [key for key in self.control_inputs if key in other_hint_keys] |
|
self.upsampler_hint_key = selected_hint_keys[0] |
|
else: |
|
self.upsampler_hint_key = None |
|
|
|
if self.upsampler_hint_key: |
|
self.prompt_upsampler = PixtralPromptUpsampler( |
|
checkpoint_dir=self.checkpoint_dir, |
|
offload_prompt_upsampler=self.offload_prompt_upsampler, |
|
) |
|
|
|
log.info( |
|
f"Prompt upsampler initialized with hint key: {self.upsampler_hint_key} and hint details: {self.hint_details}" |
|
) |
|
|
|
def _process_prompt_upsampler(self, prompt, video_path, save_folder): |
|
""" |
|
Processes and upscales a given prompt using the prompt upsampler. |
|
|
|
Args: |
|
prompt: The text prompt to upsample |
|
video_path: Path to the input video |
|
save_folder: Folder to save intermediate files |
|
|
|
Returns: |
|
str: The upsampled prompt |
|
""" |
|
if not self.prompt_upsampler: |
|
return prompt |
|
|
|
log.info(f"Upsampling prompt with controlnet: {self.upsampler_hint_key}") |
|
|
|
        if self.upsampler_hint_key == "vis":
            input_control_path = self.control_inputs[self.hint_details].get("input_control", None)
            if input_control_path is None:
                hint_key = f"control_input_{self.hint_details}"
                input_control_path = generate_control_input(
                    input_file_path=video_path,
                    save_folder=save_folder,
                    hint_key=hint_key,
                    blur_strength=self.blur_strength,
                    canny_threshold=self.canny_threshold,
                )
        else:
            input_control_path = self.control_inputs[self.upsampler_hint_key].get("input_control", None)
|
|
|
prompt = self.prompt_upsampler._prompt_upsample_with_offload(prompt=prompt, video_path=input_control_path) |
|
return prompt |
|
|
|
def _load_model(self): |
|
self.model = load_model_by_config( |
|
config_job_name=self.model_name, |
|
config_file="cosmos_transfer1/diffusion/config/transfer/config.py", |
|
model_class=self.model_class, |
|
base_checkpoint_dir=self.checkpoint_dir, |
|
) |
|
|
|
|
|
def _load_network(self): |
|
|
|
if self.checkpoint_name == "": |
|
load_network_model(self.model, "") |
|
else: |
|
load_network_model(self.model, f"{self.checkpoint_dir}/{self.checkpoint_name}") |
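            # With multiple control modalities, load each ControlNet branch separately and register them
            # as hint encoders on the model; with a single modality, merge its weights into the model
            # directly via a non-strict state-dict load.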
|
if len(self.control_inputs) > 1: |
|
hint_encoders = torch.nn.ModuleList([]) |
|
for key, spec in self.control_inputs.items(): |
|
if key in valid_hint_keys: |
|
model = load_model_by_config( |
|
config_job_name=self.model_name, |
|
config_file="cosmos_transfer1/diffusion/config/transfer/config.py", |
|
model_class=self.model_class, |
|
base_checkpoint_dir=self.checkpoint_dir, |
|
) |
|
log.info(f"Loading ctrl model from ckpt_path: {spec['ckpt_path']}") |
|
load_network_model(model, spec["ckpt_path"]) |
|
hint_encoders.append(model.model.net) |
|
del model |
|
torch.cuda.empty_cache() |
|
self.model.hint_encoders = hint_encoders |
|
else: |
|
for _, spec in self.control_inputs.items(): |
|
log.info(f"Loading ctrl model from ckpt_path: {spec['ckpt_path']}") |
|
|
|
if os.path.exists(spec["ckpt_path"]): |
|
net_state_dict = torch.load(spec["ckpt_path"], map_location="cpu", weights_only=False) |
|
else: |
|
net_state_dict = torch.load( |
|
f"{self.checkpoint_dir}/{spec['ckpt_path']}", map_location="cpu", weights_only=False |
|
) |
|
non_strict_load_model(self.model.model, net_state_dict) |
|
|
|
if self.process_group is not None: |
|
log.info("Enabling CP in base model") |
|
self.model.model.net.enable_context_parallel(self.process_group) |
|
self.model.model.base_model.net.enable_context_parallel(self.process_group) |
|
if hasattr(self.model.model, "hint_encoders"): |
|
log.info("Enabling CP in hint encoders") |
|
self.model.model.hint_encoders.net.enable_context_parallel(self.process_group) |
|
|
|
def _load_tokenizer(self): |
|
load_tokenizer_model(self.model, f"{self.checkpoint_dir}/{COSMOS_TOKENIZER_CHECKPOINT}") |
|
|
|
    def _run_tokenizer_decoding(self, sample: torch.Tensor, use_batch: bool = True) -> torch.Tensor:
        """Decode latent samples to video frames using the tokenizer decoder.

        Args:
            sample: Latent tensor from diffusion model [B, C, T, H, W]
            use_batch: Whether to decode the whole batch at once; when False (upscaler path),
                each batch element is decoded separately and the resulting patches are
                stitched back into full frames

        Returns:
            torch.Tensor: Decoded video frames as a uint8 tensor [B, C, T, H, W]
                with values in range [0, 255]
        """
|
|
|
if sample.shape[0] == 1 or use_batch: |
|
video = (1.0 + self.model.decode(sample)).clamp(0, 2) / 2 |
|
else: |
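            # Upscaler path: each batch element is a spatial patch of the upscaled video.
            # Decode the patches one at a time to limit memory, then stitch them back into
            # full frames below and bicubic-resize to 3x the patch resolution.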
|
|
|
samples = [] |
|
for sample_i in sample: |
|
samples += [self.model.decode(sample_i.unsqueeze(0)).cpu()] |
|
samples = (torch.cat(samples) + 1).clamp(0, 2) / 2 |
|
|
|
|
|
|
|
patch_h, patch_w = samples.shape[-2:] |
|
orig_size = (patch_w, patch_h) |
|
aspect_ratio = detect_aspect_ratio(orig_size) |
|
stitch_w, stitch_h = get_upscale_size(orig_size, aspect_ratio, upscale_factor=3) |
|
n_img_w = (stitch_w - 1) // patch_w + 1 |
|
n_img_h = (stitch_h - 1) // patch_h + 1 |
|
overlap_size_w = overlap_size_h = 0 |
|
if n_img_w > 1: |
|
overlap_size_w = (n_img_w * patch_w - stitch_w) // (n_img_w - 1) |
|
if n_img_h > 1: |
|
overlap_size_h = (n_img_h * patch_h - stitch_h) // (n_img_h - 1) |
|
video = merge_patches_into_video(samples, overlap_size_h, overlap_size_w, n_img_h, n_img_w) |
|
video = torch.nn.functional.interpolate(video[0], size=(patch_h * 3, patch_w * 3), mode="bicubic")[None] |
|
video = video.clamp(0, 1) |
|
video = (video * 255).to(torch.uint8).cpu() |
|
return video |
|
|
|
def _run_model_with_offload( |
|
self, |
|
prompt_embeddings: list[torch.Tensor], |
|
video_paths: list[str], |
|
negative_prompt_embeddings: Optional[list[torch.Tensor]] = None, |
|
control_inputs_list: list[dict] = None, |
|
) -> list[np.ndarray]: |
|
"""Generate world representation with automatic model offloading. |
|
|
|
Wraps the core generation process with model loading/offloading logic |
|
to minimize GPU memory usage during inference. |
|
|
|
Args: |
|
prompt_embeddings: List of text embedding tensors from T5 encoder |
|
video_paths: List of paths to input videos |
|
negative_prompt_embeddings: Optional list of embeddings for negative prompt guidance |
|
control_inputs_list: List of control input dictionaries |
|
|
|
Returns: |
|
list[np.ndarray]: List of generated world representations as numpy arrays |
|
""" |
|
if self.offload_tokenizer: |
|
self._load_tokenizer() |
|
|
|
if self.offload_network: |
|
self._load_network() |
|
|
|
prompt_embeddings = torch.cat(prompt_embeddings) |
|
if negative_prompt_embeddings is not None: |
|
negative_prompt_embeddings = torch.cat(negative_prompt_embeddings) |
|
|
|
samples = self._run_model( |
|
prompt_embeddings=prompt_embeddings, |
|
negative_prompt_embeddings=negative_prompt_embeddings, |
|
video_paths=video_paths, |
|
control_inputs_list=control_inputs_list, |
|
) |
|
|
|
if self.offload_network: |
|
self._offload_network() |
|
|
|
if self.offload_tokenizer: |
|
self._offload_tokenizer() |
|
|
|
return samples |
|
|
|
def _run_model( |
|
self, |
|
prompt_embeddings: torch.Tensor, |
|
video_paths: list[str], |
|
negative_prompt_embeddings: Optional[torch.Tensor] = None, |
|
control_inputs_list: list[dict] = None, |
|
) -> np.ndarray: |
|
""" |
|
        Core batched world generation; model offloading is handled by the _run_model_with_offload wrapper.
|
Each batch element corresponds to a (prompt, video, control_inputs) triple. |
|
""" |
|
B = len(video_paths) |
|
assert prompt_embeddings.shape[0] == B, "Batch size mismatch for prompt embeddings" |
|
if negative_prompt_embeddings is not None: |
|
assert negative_prompt_embeddings.shape[0] == B, "Batch size mismatch for negative prompt embeddings" |
|
assert len(control_inputs_list) == B, "Batch size mismatch for control_inputs_list" |
|
|
|
log.info("Starting data augmentation") |
|
|
|
|
|
log.info(f"regional_prompts passed to _run_model: {self.regional_prompts}") |
|
log.info(f"region_definitions passed to _run_model: {self.region_definitions}") |
|
regional_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(self.regional_prompts) |
|
regional_contexts = None |
|
region_masks = None |
|
if self.regional_prompts and self.region_definitions: |
|
|
|
_, regional_contexts, region_masks = prepare_regional_prompts( |
|
model=self.model, |
|
global_prompt=prompt_embeddings, |
|
regional_prompts=regional_embeddings, |
|
region_definitions=self.region_definitions, |
|
batch_size=1, |
|
time_dim=self.num_video_frames, |
|
height=self.height // self.model.tokenizer.spatial_compression_factor, |
|
width=self.width // self.model.tokenizer.spatial_compression_factor, |
|
device=torch.device("cuda"), |
|
compression_factor=self.model.tokenizer.spatial_compression_factor, |
|
) |
|
|
|
is_upscale_case = any("upscale" in control_inputs for control_inputs in control_inputs_list) |
|
|
|
data_batch, state_shape = get_batched_ctrl_batch( |
|
model=self.model, |
|
prompt_embeddings=prompt_embeddings, |
|
negative_prompt_embeddings=negative_prompt_embeddings, |
|
height=self.height, |
|
width=self.width, |
|
fps=self.fps, |
|
num_video_frames=self.num_video_frames, |
|
input_video_paths=video_paths, |
|
control_inputs_list=control_inputs_list, |
|
blur_strength=self.blur_strength, |
|
canny_threshold=self.canny_threshold, |
|
) |
|
|
|
if regional_contexts is not None: |
|
data_batch["regional_contexts"] = regional_contexts |
|
data_batch["region_masks"] = region_masks |
|
|
|
log.info("Completed data augmentation") |
|
|
|
hint_key = data_batch["hint_key"] |
|
control_input = data_batch[hint_key] |
|
input_video = data_batch.get("input_video", None) |
|
control_weight = data_batch.get("control_weight", None) |
|
num_new_generated_frames = self.num_video_frames - self.num_input_frames |
|
B, C, T, H, W = control_input.shape |
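        # Pad the control input (and, if present, the input video and spatial control-weight map)
        # by repeating the last frame so the total length splits evenly into generation clips.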
|
if (T - self.num_input_frames) % num_new_generated_frames != 0: |
|
pad_t = num_new_generated_frames - ((T - self.num_input_frames) % num_new_generated_frames) |
|
pad_frames = control_input[:, :, -1:].repeat(1, 1, pad_t, 1, 1) |
|
control_input = torch.cat([control_input, pad_frames], dim=2) |
|
if input_video is not None: |
|
pad_video = input_video[:, :, -1:].repeat(1, 1, pad_t, 1, 1) |
|
input_video = torch.cat([input_video, pad_video], dim=2) |
|
num_total_frames_with_padding = control_input.shape[2] |
|
if ( |
|
isinstance(control_weight, torch.Tensor) |
|
and control_weight.ndim > 5 |
|
and num_total_frames_with_padding > control_weight.shape[3] |
|
): |
|
pad_t = num_total_frames_with_padding - control_weight.shape[3] |
|
pad_weight = control_weight[:, :, :, -1:].repeat(1, 1, 1, pad_t, 1, 1) |
|
control_weight = torch.cat([control_weight, pad_weight], dim=3) |
|
else: |
|
num_total_frames_with_padding = T |
|
N_clip = (num_total_frames_with_padding - self.num_input_frames) // num_new_generated_frames |
|
|
|
video = [] |
|
prev_frames = None |
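        # Autoregressive long-video loop: each clip produces num_new_generated_frames new frames and is
        # conditioned on the last num_input_frames frames of the previously generated clip.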
|
for i_clip in tqdm(range(N_clip)): |
|
|
|
data_batch_i = {k: v for k, v in data_batch.items()} |
|
start_frame = num_new_generated_frames * i_clip |
|
end_frame = num_new_generated_frames * (i_clip + 1) + self.num_input_frames |
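            # Video-to-video: when an input video is available, encode the current window and diffuse it
            # to sigma_max so sampling starts from a partially noised version of the input. The upscaler
            # path encodes one batch element (patch) at a time to limit memory.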
|
|
|
|
|
if input_video is not None: |
|
if is_upscale_case: |
|
                    x_sigma_max = []
                    for b in range(B):
                        input_frames = input_video[b : b + 1, :, start_frame:end_frame].cuda()
                        x0 = self.model.encode(input_frames).contiguous()
                        log.info(f"x0 shape: {x0.shape}")
                        x_sigma_max.append(self.model.get_x_from_clean(x0, self.sigma_max, seed=(self.seed + i_clip)))
                    x_sigma_max = torch.cat(x_sigma_max)
                    log.info(f"x_sigma_max shape: {x_sigma_max.shape}")
                else:
                    input_frames = input_video[:, :, start_frame:end_frame].cuda()
                    x0 = self.model.encode(input_frames).contiguous()
                    log.info(f"x0 shape: {x0.shape}")
                    x_sigma_max = self.model.get_x_from_clean(x0, self.sigma_max, seed=(self.seed + i_clip))
                    log.info(f"x_sigma_max shape: {x_sigma_max.shape}")
|
|
|
else: |
|
x_sigma_max = None |
|
|
|
data_batch_i[hint_key] = control_input[:, :, start_frame:end_frame].cuda() |
|
latent_hint = [] |
|
log.info("Starting latent encoding") |
|
for b in range(B): |
|
data_batch_p = {k: v for k, v in data_batch_i.items()} |
|
data_batch_p[hint_key] = data_batch_i[hint_key][b : b + 1] |
|
if len(control_inputs_list) >= 1 and len(control_inputs_list[0]) > 1: |
|
latent_hint_i = [] |
|
for idx in range(0, data_batch_p[hint_key].size(1), 3): |
|
x_rgb = data_batch_p[hint_key][:, idx : idx + 3] |
|
latent_hint_i.append(self.model.encode(x_rgb)) |
|
latent_hint.append(torch.cat(latent_hint_i).unsqueeze(0)) |
|
else: |
|
latent_hint.append(self.model.encode_latent(data_batch_p)) |
|
data_batch_i["latent_hint"] = latent_hint = torch.cat(latent_hint) |
|
log.info("Completed latent encoding") |
|
|
|
if isinstance(control_weight, torch.Tensor) and control_weight.ndim > 4: |
|
control_weight_t = control_weight[..., start_frame:end_frame, :, :] |
|
t, h, w = latent_hint.shape[-3:] |
|
data_batch_i["control_weight"] = resize_control_weight_map(control_weight_t, (t, h // 2, w // 2)) |
|
|
|
|
|
if i_clip == 0: |
|
num_input_frames = 0 |
|
latent_tmp = latent_hint if latent_hint.ndim == 5 else latent_hint[:, 0] |
|
condition_latent = torch.zeros_like(latent_tmp) |
|
else: |
|
num_input_frames = self.num_input_frames |
|
prev_frames = split_video_into_patches(prev_frames, control_input.shape[-2], control_input.shape[-1]) |
|
input_frames = prev_frames.bfloat16().cuda() / 255.0 * 2 - 1 |
|
condition_latent = self.model.encode(input_frames).contiguous() |
|
|
|
|
|
log.info("Starting diffusion sampling") |
|
latents = generate_world_from_control( |
|
model=self.model, |
|
state_shape=state_shape, |
|
is_negative_prompt=True, |
|
data_batch=data_batch_i, |
|
guidance=self.guidance, |
|
num_steps=self.num_steps, |
|
seed=(self.seed + i_clip), |
|
condition_latent=condition_latent, |
|
num_input_frames=num_input_frames, |
|
sigma_max=self.sigma_max if x_sigma_max is not None else None, |
|
x_sigma_max=x_sigma_max, |
|
                use_batch_processing=not is_upscale_case,
|
chunking=self.chunking, |
|
) |
|
log.info("Completed diffusion sampling") |
|
log.info("Starting VAE decode") |
|
            frames = self._run_tokenizer_decoding(latents, use_batch=not is_upscale_case)
|
log.info("Completed VAE decode") |
|
|
|
if i_clip == 0: |
|
video.append(frames) |
|
else: |
|
video.append(frames[:, :, self.num_input_frames :]) |
|
|
|
prev_frames = torch.zeros_like(frames) |
|
prev_frames[:, :, : self.num_input_frames] = frames[:, :, -self.num_input_frames :] |
|
|
|
video = torch.cat(video, dim=2)[:, :, :T] |
|
video = video.permute(0, 2, 3, 4, 1).numpy() |
|
return video |
|
|
|
def generate( |
|
self, |
|
prompt: str | list[str], |
|
video_path: str | list[str], |
|
negative_prompt: Optional[str | list[str]] = None, |
|
control_inputs: dict | list[dict] = None, |
|
save_folder: str = "outputs/", |
|
batch_size: int = 1, |
|
) -> tuple[np.ndarray, str | list[str]] | None: |
|
"""Generate video from text prompt and control video. |
|
|
|
Pipeline steps: |
|
1. Run safety checks on input prompt |
|
2. Convert prompt to embeddings |
|
3. Generate video frames using diffusion |
|
4. Run safety checks and apply face blur on generated video frames |
|
|
|
Args: |
|
            prompt: Text description(s) of the desired video(s); a single string or a list of strings
            video_path: Path(s) to the input video(s); a single path or a list of paths
|
negative_prompt: Optional text to guide what not to generate |
|
control_inputs: Control inputs for guided generation |
|
save_folder: Folder to save intermediate files |
|
batch_size: Number of videos to process simultaneously |
|
|
|
Returns: |
|
            tuple: (
                List of generated videos as uint8 np.ndarray [T, H, W, C],
                List of final prompts used for generation (may be enhanced)
            ), or None if all prompts or all generated videos fail guardrail safety checks
|
""" |
|
|
|
|
|
|
|
|
|
prompts = [prompt] if isinstance(prompt, str) else prompt |
|
video_paths = [video_path] if isinstance(video_path, str) else video_path |
|
control_inputs_list = [control_inputs] if not isinstance(control_inputs, list) else control_inputs |
|
|
|
        assert len(prompts) == batch_size, "Number of prompts must match batch size"
        assert len(video_paths) == batch_size, "Number of videos must match batch size"
        assert len(control_inputs_list) == batch_size, "Number of control inputs must match batch size"
|
log.info(f"Running batch generation with batch_size={batch_size}") |
|
|
|
|
|
all_videos = [] |
|
all_final_prompts = [] |
|
|
|
|
|
        if self.prompt_upsampler and int(os.environ.get("RANK", "0")) == 0:
|
self._push_torchrun_environ_variables() |
|
upsampled_prompts = [] |
|
for i, (single_prompt, single_video_path) in enumerate(zip(prompts, video_paths)): |
|
log.info(f"Upsampling prompt {i+1}/{batch_size}: {single_prompt[:50]}...") |
|
video_save_subfolder = os.path.join(save_folder, f"video_{i}") |
|
os.makedirs(video_save_subfolder, exist_ok=True) |
|
upsampled_prompt = self._process_prompt_upsampler( |
|
single_prompt, single_video_path, video_save_subfolder |
|
) |
|
upsampled_prompts.append(upsampled_prompt) |
|
log.info(f"Upsampled prompt {i+1}: {upsampled_prompt[:50]}...") |
|
self._pop_torchrun_environ_variables() |
|
prompts = upsampled_prompts |
|
|
|
log.info("Running guardrail checks on all prompts") |
|
safe_indices = [] |
|
for i, single_prompt in enumerate(prompts): |
|
is_safe = self._run_guardrail_on_prompt_with_offload(single_prompt) |
|
if is_safe: |
|
safe_indices.append(i) |
|
else: |
|
log.critical(f"Input text prompt {i+1} is not safe") |
|
|
|
if not safe_indices: |
|
log.critical("All prompts failed safety checks") |
|
return None |
|
|
|
safe_prompts = [prompts[i] for i in safe_indices] |
|
safe_video_paths = [video_paths[i] for i in safe_indices] |
|
safe_control_inputs = [control_inputs_list[i] for i in safe_indices] |
|
|
|
log.info("Running text embedding on all prompts") |
|
all_prompt_embeddings = [] |
|
|
|
|
|
embedding_batch_size = min(batch_size, 8) |
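        # Encode prompts in chunks of at most 8 to bound T5 memory usage. When a negative prompt is given,
        # it is interleaved after each positive prompt so the (positive, negative) embeddings can be paired
        # back up after encoding.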
|
for i in range(0, len(safe_prompts), embedding_batch_size): |
|
batch_prompts = safe_prompts[i : i + embedding_batch_size] |
|
if negative_prompt: |
|
batch_prompts_with_neg = [] |
|
for p in batch_prompts: |
|
batch_prompts_with_neg.extend([p, negative_prompt]) |
|
else: |
|
batch_prompts_with_neg = batch_prompts |
|
log.info("Starting T5 compute") |
|
prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(batch_prompts_with_neg) |
|
log.info("Completed T5 compute") |
|
|
|
if negative_prompt: |
|
for j in range(0, len(prompt_embeddings), 2): |
|
all_prompt_embeddings.append((prompt_embeddings[j], prompt_embeddings[j + 1])) |
|
else: |
|
for emb in prompt_embeddings: |
|
all_prompt_embeddings.append((emb, None)) |
|
log.info("Finish text embedding on prompt") |
|
|
|
|
|
log.info("Run generation") |
|
|
|
all_neg_embeddings = [emb[1] for emb in all_prompt_embeddings] |
|
all_prompt_embeddings = [emb[0] for emb in all_prompt_embeddings] |
|
videos = self._run_model_with_offload( |
|
prompt_embeddings=all_prompt_embeddings, |
|
negative_prompt_embeddings=all_neg_embeddings, |
|
video_paths=safe_video_paths, |
|
control_inputs_list=safe_control_inputs, |
|
) |
|
log.info("Finish generation") |
|
|
|
log.info("Run guardrail on generated videos") |
|
for i, video in enumerate(videos): |
|
safe_video = self._run_guardrail_on_video_with_offload(video) |
|
if safe_video is not None: |
|
all_videos.append(safe_video) |
|
all_final_prompts.append(safe_prompts[i]) |
|
else: |
|
log.critical(f"Generated video {i+1} is not safe") |
|
if not all_videos: |
|
log.critical("All generated videos failed safety checks") |
|
return None |
|
return all_videos, all_final_prompts |
|
|
|
|
|
class DiffusionControl2WorldMultiviewGenerationPipeline(DiffusionControl2WorldGenerationPipeline): |
|
def __init__(self, *args, is_lvg_model=False, n_clip_max=-1, **kwargs): |
|
        super().__init__(*args, **kwargs)
|
self.is_lvg_model = is_lvg_model |
|
self.n_clip_max = n_clip_max |
|
|
|
def _run_tokenizer_decoding(self, sample: torch.Tensor): |
|
"""Decode latent samples to video frames using the tokenizer decoder. |
|
|
|
Args: |
|
sample: Latent tensor from diffusion model [B, C, T, H, W] |
|
|
|
        Returns:
            list[np.ndarray]: [grid_video, video], where grid_video is a 2x3 multi-camera preview grid
                and video is the per-view concatenated decode, both as uint8 numpy arrays [T, H, W, C]
                with values in range [0, 255]
        """
|
|
|
if self.model.n_views == 5: |
|
video_arrangement = [1, 0, 2, 3, 0, 4] |
|
elif self.model.n_views == 6: |
|
video_arrangement = [1, 0, 2, 4, 3, 5] |
|
else: |
|
raise ValueError(f"Unsupported number of views: {self.model.n_views}") |
|
|
|
video = (1.0 + self.model.decode(sample)).clamp(0, 2) / 2 |
|
video_segments = einops.rearrange(video, "b c (v t) h w -> b c v t h w", v=self.model.n_views) |
|
grid_video = torch.stack( |
|
[video_segments[:, :, i] for i in video_arrangement], |
|
dim=2, |
|
) |
|
grid_video = einops.rearrange(grid_video, "b c (h w) t h1 w1 -> b c t (h h1) (w w1)", h=2, w=3) |
|
grid_video = (grid_video[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy() |
|
video = (video[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy() |
|
|
|
return [grid_video, video] |
|
|
|
def _run_model_with_offload( |
|
self, |
|
prompt_embedding: torch.Tensor, |
|
view_condition_video="", |
|
initial_condition_video="", |
|
control_inputs: dict = None, |
|
) -> np.ndarray: |
|
"""Generate world representation with automatic model offloading. |
|
|
|
Wraps the core generation process with model loading/offloading logic |
|
to minimize GPU memory usage during inference. |
|
|
|
Args: |
|
prompt_embedding: Text embedding tensor from T5 encoder |
|
view_condition_video: Path to input sv view condition video |
|
initial_condition_video: Path to input mv initial frames |
|
control_inputs: Dictionary of control modalities and corresponding inputs |
|
|
|
Returns: |
|
np.ndarray: Generated world representation as numpy array |
|
""" |
|
if self.offload_tokenizer: |
|
self._load_tokenizer() |
|
|
|
if self.offload_network: |
|
self._load_network() |
|
|
|
sample = self._run_model( |
|
prompt_embedding, view_condition_video, initial_condition_video, control_inputs=control_inputs |
|
) |
|
|
|
if self.offload_network: |
|
self._offload_network() |
|
|
|
if self.offload_tokenizer: |
|
self._offload_tokenizer() |
|
|
|
return sample |
|
|
|
def _run_model( |
|
self, |
|
embedding: torch.Tensor, |
|
view_condition_video="", |
|
initial_condition_video="", |
|
control_inputs: dict = None, |
|
) -> torch.Tensor: |
|
"""Generate video frames using the diffusion model. |
|
|
|
Args: |
|
            embedding: Per-view text embedding tensors from T5 encoder
|
view_condition_video: Path to input sv view condition video |
|
initial_condition_video: Path to input mv initial frames |
|
control_inputs: Dictionary of control modalities and corresponding inputs |
|
|
|
Returns: |
|
Tensor of generated video frames |
|
|
|
Note: |
|
Model and tokenizer are automatically offloaded after inference |
|
if offloading is enabled. |
|
""" |
|
|
|
assert len(embedding) == self.model.n_views |
|
|
|
view_condition_video, fps = read_video_or_image_into_frames_BCTHW( |
|
view_condition_video, |
|
normalize=False, |
|
max_frames=6000, |
|
also_return_fps=True, |
|
) |
|
view_condition_video = resize_video( |
|
view_condition_video, self.height, self.width, interpolation=cv2.INTER_LINEAR |
|
) |
|
view_condition_video = torch.from_numpy(view_condition_video) |
|
total_T = view_condition_video.shape[2] |
|
|
|
data_batch, state_shape = get_video_batch_for_multiview_model( |
|
model=self.model, |
|
prompt_embedding=embedding, |
|
height=self.height, |
|
width=self.width, |
|
fps=self.fps, |
|
num_video_frames=self.num_video_frames * len(embedding), |
|
frame_repeat_negative_condition=0, |
|
) |
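        # "first_cam" conditions every view on the front-camera video only; long-video (lvg) models
        # additionally condition on the first num_input_frames frames of every view ("first_cam_and_first_n").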
|
|
|
self.model.condition_location = "first_cam_and_first_n" if self.is_lvg_model else "first_cam" |
|
|
|
if self.is_lvg_model: |
|
if os.path.isdir(initial_condition_video): |
|
initial_condition_videos = [] |
|
fnames = sorted(os.listdir(initial_condition_video)) |
|
for fname in fnames: |
|
if fname.endswith(".mp4"): |
|
try: |
|
input_view_id = int(fname.split(".")[0]) |
|
except ValueError: |
|
log.warning(f"Could not parse video file name {fname} into view id") |
|
continue |
|
                        initial_condition_video_n, _ = read_video_or_image_into_frames_BCTHW(
                            os.path.join(initial_condition_video, fname),
                            normalize=False,
                            max_frames=self.num_input_frames,
                            also_return_fps=True,
                        )
|
initial_condition_videos.append(torch.from_numpy(initial_condition_video_n)) |
|
initial_condition_video = torch.cat(initial_condition_videos, dim=2) |
|
else: |
|
initial_condition_video, _ = read_video_or_image_into_frames_BCTHW( |
|
initial_condition_video, |
|
normalize=False, |
|
max_frames=6000, |
|
also_return_fps=True, |
|
) |
|
initial_condition_video = torch.from_numpy(initial_condition_video) |
|
else: |
|
initial_condition_video = None |
|
|
|
data_batch = get_ctrl_batch_mv( |
|
self.height, self.width, data_batch, total_T, control_inputs, self.model.n_views, self.num_video_frames |
|
) |
|
|
|
hint_key = data_batch["hint_key"] |
|
input_video = None |
|
control_input = data_batch[hint_key] |
|
control_weight = data_batch["control_weight"] |
|
|
|
num_new_generated_frames = self.num_video_frames - self.num_input_frames |
|
B, C, T, H, W = control_input.shape |
|
T = T // self.model.n_views |
|
assert T == total_T |
|
|
|
|
|
|
|
if self.is_lvg_model: |
|
N_clip = (T - self.num_input_frames) // num_new_generated_frames |
|
if self.n_clip_max > 0: |
|
N_clip = min(self.n_clip_max, N_clip) |
|
else: |
|
N_clip = 1 |
|
log.info("Model is not Long-video generation model, overwriting N_clip to 1") |
|
|
|
video = [] |
|
for i_clip in tqdm(range(N_clip)): |
|
data_batch_i = {k: v for k, v in data_batch.items()} |
|
start_frame = num_new_generated_frames * i_clip |
|
end_frame = num_new_generated_frames * (i_clip + 1) + self.num_input_frames |
|
|
|
if input_video is not None: |
|
x_sigma_max = [] |
|
for b in range(B): |
|
input_frames = input_video[b : b + 1, :, start_frame:end_frame].cuda() |
|
x0 = self.model.encode(input_frames).contiguous() |
|
x_sigma_max.append(self.model.get_x_from_clean(x0, self.sigma_max, seed=(self.seed + i_clip))) |
|
x_sigma_max = torch.cat(x_sigma_max) |
|
else: |
|
x_sigma_max = None |
|
|
|
control_input_BVCT = einops.rearrange(control_input, "B C (V T) H W -> (B V) C T H W", V=self.model.n_views) |
|
control_input_i = control_input_BVCT[:, :, start_frame:end_frame].cuda() |
|
|
|
data_batch_i[hint_key] = einops.rearrange( |
|
control_input_i, "(B V) C T H W -> B C (V T) H W", V=self.model.n_views |
|
) |
|
|
|
condition_input_i = view_condition_video[:, :, start_frame:end_frame].cuda() |
|
|
|
latent_hint = [] |
|
for b in range(B): |
|
data_batch_p = {k: v for k, v in data_batch_i.items()} |
|
data_batch_p[hint_key] = data_batch_i[hint_key][b : b + 1] |
|
if len(control_inputs) > 1: |
|
latent_hint_i = [] |
|
for idx in range(0, data_batch_p[hint_key].size(1), 3): |
|
x_rgb = data_batch_p[hint_key][:, idx : idx + 3] |
|
latent_hint_i.append(self.model.encode(x_rgb)) |
|
latent_hint.append(torch.cat(latent_hint_i).unsqueeze(0)) |
|
else: |
|
latent_hint.append(self.model.encode_latent(data_batch_p)) |
|
data_batch_i["latent_hint"] = latent_hint = torch.cat(latent_hint) |
|
|
|
if "regional_contexts" in data_batch_i: |
|
data_batch_i["regional_contexts"] = broadcast(data_batch_i["regional_contexts"], to_tp=True, to_cp=True) |
|
data_batch_i["region_masks"] = broadcast(data_batch_i["region_masks"], to_tp=True, to_cp=True) |
|
|
|
if isinstance(control_weight, torch.Tensor) and control_weight.ndim > 4: |
|
control_weight_t = control_weight[..., start_frame:end_frame, :, :].cuda() |
|
t, h, w = latent_hint.shape[-3:] |
|
data_batch_i["control_weight"] = resize_control_weight_map(control_weight_t, (t, h // 2, w // 2)) |
|
|
|
if i_clip == 0: |
|
if initial_condition_video is not None: |
|
prev_frames_blank = torch.zeros((B, self.model.n_views, C, self.num_video_frames, H, W)).to( |
|
view_condition_video |
|
) |
|
|
|
initial_condition_video_frames_BVCT = einops.rearrange( |
|
initial_condition_video, "B C (V T) H W -> B V C T H W", V=self.model.n_views |
|
) |
|
prev_frames_blank[:, :, :, : self.num_input_frames] = initial_condition_video_frames_BVCT[ |
|
:, :, :, start_frame : start_frame + self.num_input_frames |
|
].cuda() |
|
prev_frames = einops.rearrange(prev_frames_blank, "B V C T H W -> B C (V T) H W") |
|
num_input_frames = self.num_input_frames |
|
else: |
|
num_input_frames = 0 |
|
prev_frames = None |
|
else: |
|
num_input_frames = self.num_input_frames |
|
condition_latent = self.get_condition_latent( |
|
state_shape, |
|
data_batch_i, |
|
cond_video=condition_input_i, |
|
prev_frames=prev_frames, |
|
patch_h=H, |
|
patch_w=W, |
|
skip_reencode=False, |
|
).bfloat16() |
|
|
|
latents = generate_world_from_control( |
|
model=self.model, |
|
state_shape=self.model.state_shape, |
|
is_negative_prompt=False, |
|
data_batch=data_batch_i, |
|
guidance=self.guidance, |
|
num_steps=self.num_steps, |
|
seed=(self.seed + i_clip), |
|
condition_latent=condition_latent, |
|
num_input_frames=num_input_frames, |
|
sigma_max=self.sigma_max if x_sigma_max is not None else None, |
|
x_sigma_max=x_sigma_max, |
|
augment_sigma=0.0, |
|
) |
|
torch.cuda.empty_cache() |
|
_, frames = self._run_tokenizer_decoding(latents) |
|
frames = torch.from_numpy(frames).permute(3, 0, 1, 2)[None] |
|
frames_BVCT = einops.rearrange(frames, "B C (V T) H W -> B V C T H W", V=self.model.n_views) |
|
if i_clip == 0: |
|
video.append(frames_BVCT) |
|
else: |
|
frames_BVCT_non_overlap = frames_BVCT[:, :, :, num_input_frames:] |
|
video.append(frames_BVCT_non_overlap) |
|
|
|
prev_frames = torch.zeros_like(frames_BVCT) |
|
n_copy = max(1, abs(self.num_input_frames)) |
|
prev_frames[:, :, :, :n_copy] = frames_BVCT[:, :, :, -n_copy:] |
|
prev_frames = einops.rearrange(prev_frames, "B V C T H W -> B C (V T) H W") |
|
|
|
video = torch.cat(video, dim=3) |
|
video = einops.rearrange(video, "B V C T H W -> B C (V T) H W") |
|
video = video[0].permute(1, 2, 3, 0).numpy() |
|
return video |
|
|
|
def get_condition_latent( |
|
self, |
|
state_shape, |
|
data_batch_i, |
|
cond_video=None, |
|
prev_frames=None, |
|
patch_h=1024, |
|
patch_w=1024, |
|
skip_reencode=False, |
|
prev_latents=None, |
|
): |
|
""" |
|
Create the condition latent used in this loop for generation from RGB frames |
|
Args: |
|
model: |
|
state_shape: tuple (C T H W), shape of latent to be generated |
|
data_batch_i: (dict) this is only used to get batch size |
|
multi_cam: (bool) whether to use multicam processing or revert to original behavior from tpsp_demo |
|
cond_video: (tensor) the front view video for conditioning sv2mv |
|
prev_frames: (tensor) frames generated in previous loop |
|
patch_h: (int) |
|
patch_w: (int) |
|
skip_reencode: (bool) whether to use the tokenizer to encode prev_frames, or read from prev_latents directly |
|
prev_latents: (tensor) latent generated in previous loop, must not be None if skip_reencode |
|
|
|
Returns: |
|
|
|
""" |
|
|
|
B = data_batch_i["video"].shape[0] |
|
|
|
latent_sample = torch.zeros(state_shape).unsqueeze(0).repeat(B, 1, 1, 1, 1).cuda() |
|
latent_sample = einops.rearrange(latent_sample, "B C (V T) H W -> B V C T H W", V=self.model.n_views) |
|
log.info(f"model.sigma_data {self.model.sigma_data}") |
|
if self.model.config.conditioner.video_cond_bool.condition_location.endswith("first_n"): |
|
if skip_reencode: |
|
assert prev_latents is not None |
|
prev_latents = einops.rearrange(prev_latents, "B C (V T) H W -> B V C T H W", V=self.model.n_views) |
|
latent_sample = prev_latents.clone() |
|
else: |
|
prev_frames = split_video_into_patches(prev_frames, patch_h, patch_w) |
|
for b in range(prev_frames.shape[0]): |
|
input_frames = prev_frames[b : b + 1].cuda() / 255.0 * 2 - 1 |
|
input_frames = einops.rearrange(input_frames, "1 C (V T) H W -> V C T H W", V=self.model.n_views) |
|
encoded_frames = self.model.tokenizer.encode(input_frames).contiguous() * self.model.sigma_data |
|
latent_sample[b : b + 1, :] = encoded_frames |
|
|
|
if self.model.config.conditioner.video_cond_bool.condition_location.startswith("first_cam"): |
|
assert cond_video is not None |
|
cond_video = split_video_into_patches(cond_video, patch_h, patch_w) |
|
for b in range(cond_video.shape[0]): |
|
input_frames = cond_video[b : b + 1].cuda() / 255.0 * 2 - 1 |
|
|
|
latent_sample[ |
|
b : b + 1, |
|
0, |
|
] = ( |
|
self.model.tokenizer.encode(input_frames).contiguous() * self.model.sigma_data |
|
) |
|
|
|
latent_sample = einops.rearrange(latent_sample, " B V C T H W -> B C (V T) H W") |
|
log.info(f"latent_sample, {latent_sample[:,0,:,0,0]}") |
|
|
|
return latent_sample |
|
|
|
def build_mv_prompt(self, mv_prompts, n_views): |
|
""" |
|
Apply multiview prompt formatting to the input prompt such that hte text conditioning matches that used during |
|
training. |
|
Args: |
|
prompt: caption of one scene, with prompt of each view separated by ";" |
|
n_views: number of cameras to format the caption to |
|
|
|
Returns: |
|
|
|
""" |
|
if n_views == 5: |
|
base_prompts = [ |
|
"The video is captured from a camera mounted on a car. The camera is facing forward.", |
|
"The video is captured from a camera mounted on a car. The camera is facing to the left.", |
|
"The video is captured from a camera mounted on a car. The camera is facing to the right.", |
|
"The video is captured from a camera mounted on a car. The camera is facing the rear left side.", |
|
"The video is captured from a camera mounted on a car. The camera is facing the rear right side.", |
|
] |
|
elif n_views == 6: |
|
base_prompts = [ |
|
"The video is captured from a camera mounted on a car. The camera is facing forward.", |
|
"The video is captured from a camera mounted on a car. The camera is facing to the left.", |
|
"The video is captured from a camera mounted on a car. The camera is facing to the right.", |
|
"The video is captured from a camera mounted on a car. The camera is facing backwards.", |
|
"The video is captured from a camera mounted on a car. The camera is facing the rear left side.", |
|
"The video is captured from a camera mounted on a car. The camera is facing the rear right side.", |
|
            ]
        else:
            raise ValueError(f"Unsupported number of views: {n_views}")
|
|
|
log.info(f"Reading multiview prompts, found {len(mv_prompts)} splits") |
|
n = len(mv_prompts) |
|
if n < n_views: |
|
mv_prompts += base_prompts[n:] |
|
else: |
|
mv_prompts = mv_prompts[:n_views] |
|
|
|
for vid, p in enumerate(mv_prompts): |
|
if not p.startswith(base_prompts[vid]): |
|
mv_prompts[vid] = base_prompts[vid] + " " + p |
|
log.info(f"Adding missing camera caption to view {vid}, {p[:30]}") |
|
|
|
log.info(f"Procced multiview prompts, {len(mv_prompts)} splits") |
|
return mv_prompts |
|
|
|
def generate( |
|
self, |
|
prompts: list, |
|
view_condition_video: str, |
|
initial_condition_video: str, |
|
control_inputs: dict = None, |
|
save_folder: str = "outputs/", |
|
) -> tuple[np.ndarray, str] | None: |
|
"""Generate video from text prompt and control video. |
|
|
|
Pipeline steps: |
|
1. Run safety checks on input prompt |
|
2. Convert prompt to embeddings |
|
3. Generate video frames using diffusion |
|
4. Run safety checks and apply face blur on generated video frames |
|
|
|
Args: |
|
            prompts: List of per-view text prompts (formatted via build_mv_prompt)
            view_condition_video: Path to the single-view condition video
            initial_condition_video: Path to the multi-view initial condition frames (a file, or a directory of per-view .mp4 files)
            control_inputs: Control inputs for guided generation
            save_folder: Folder to save intermediate files
|
|
|
Returns: |
|
tuple: ( |
|
Generated video frames as uint8 np.ndarray [T, H, W, C], |
|
                Final per-view prompts used for generation (may be enhanced)
            ), or None if the input prompt fails guardrail safety checks
|
""" |
|
|
|
log.info(f"Run with view condition video path: {view_condition_video}") |
|
if initial_condition_video: |
|
log.info(f"Run with initial condition video path: {initial_condition_video}") |
|
mv_prompts = self.build_mv_prompt(prompts, self.model.n_views) |
|
log.info(f"Run with prompt: {mv_prompts}") |
|
|
|
|
|
log.info("Run guardrail on prompt") |
|
is_safe = self._run_guardrail_on_prompt_with_offload(". ".join(mv_prompts)) |
|
if not is_safe: |
|
log.critical("Input text prompt is not safe") |
|
return None |
|
log.info("Pass guardrail on prompt") |
|
|
|
prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(mv_prompts) |
|
prompt_embedding = torch.concat(prompt_embeddings, dim=0).cuda() |
|
|
|
log.info("Finish text embedding on prompt") |
|
|
|
|
|
log.info("Run generation") |
|
|
|
video = self._run_model_with_offload( |
|
prompt_embedding, |
|
view_condition_video, |
|
initial_condition_video, |
|
control_inputs=control_inputs, |
|
) |
|
log.info("Finish generation") |
|
log.info("Run guardrail on generated video") |
|
video = self._run_guardrail_on_video_with_offload(video) |
|
if video is None: |
|
log.critical("Generated video is not safe") |
|
raise ValueError("Guardrail check failed: Generated video is unsafe") |
|
|
|
log.info("Pass guardrail on generated video") |
|
|
|
return video, mv_prompts |
|
|
|
|
|
class DistilledControl2WorldGenerationPipeline(DiffusionControl2WorldGenerationPipeline): |
|
"""Pipeline for distilled ControlNet video2video inference.""" |
|
|
|
def _load_network(self): |
|
log.info("Loading distilled consolidated checkpoint") |
|
|
|
|
|
from cosmos_transfer1.diffusion.inference.inference_utils import skip_init_linear |
|
|
|
with skip_init_linear(): |
|
self.model.set_up_model() |
|
checkpoint_path = f"{self.checkpoint_dir}/{self.checkpoint_name}" |
|
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) |
|
state_dict = checkpoint.get("model", checkpoint) |
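        # The consolidated distilled checkpoint contains both the base network ("net.base_model.net.*")
        # and the ControlNet branch ("net.net_ctrl.*"); split the state dict and load each part into the
        # corresponding module.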
|
|
|
|
|
base_state_dict = {} |
|
ctrl_state_dict = {} |
|
|
|
for k, v in state_dict.items(): |
|
if k.startswith("net.base_model.net."): |
|
base_key = k[len("net.base_model.net.") :] |
|
base_state_dict[base_key] = v |
|
elif k.startswith("net.net_ctrl."): |
|
ctrl_key = k[len("net.net_ctrl.") :] |
|
ctrl_state_dict[ctrl_key] = v |
|
|
|
|
|
if base_state_dict: |
|
self.model.model["net"].base_model.net.load_state_dict(base_state_dict, strict=False) |
|
self.model.model.base_model.load_state_dict(base_state_dict, strict=False) |
|
|
|
if ctrl_state_dict: |
|
self.model.model["net"].net_ctrl.load_state_dict(ctrl_state_dict, strict=False) |
|
self.model.model.cuda() |
|
|
|
if self.process_group is not None: |
|
log.info("Enabling CP in base model") |
|
self.model.model.net.enable_context_parallel(self.process_group) |
|
|
|
def _run_model( |
|
self, |
|
prompt_embeddings: torch.Tensor, |
|
video_paths: list[str], |
|
negative_prompt_embeddings: Optional[torch.Tensor] = None, |
|
control_inputs_list: list[dict] = None, |
|
) -> np.ndarray: |
|
""" |
|
        Core batched world generation; model offloading is handled by the _run_model_with_offload wrapper.
|
Each batch element corresponds to a (prompt, video, control_inputs) triple. |
|
""" |
|
B = len(video_paths) |
|
print(f"video paths: {video_paths}") |
|
assert prompt_embeddings.shape[0] == B, "Batch size mismatch for prompt embeddings" |
|
if negative_prompt_embeddings is not None: |
|
assert negative_prompt_embeddings.shape[0] == B, "Batch size mismatch for negative prompt embeddings" |
|
assert len(control_inputs_list) == B, "Batch size mismatch for control_inputs_list" |
|
|
|
log.info("Starting data augmentation") |
|
|
|
log.info(f"Regional prompts not supported when using distilled model, dropping: {self.regional_prompts}") |
|
|
|
|
|
data_batch, state_shape = get_batched_ctrl_batch( |
|
model=self.model, |
|
prompt_embeddings=prompt_embeddings, |
|
negative_prompt_embeddings=negative_prompt_embeddings, |
|
height=self.height, |
|
width=self.width, |
|
fps=self.fps, |
|
num_video_frames=self.num_video_frames, |
|
input_video_paths=video_paths, |
|
control_inputs_list=control_inputs_list, |
|
blur_strength=self.blur_strength, |
|
canny_threshold=self.canny_threshold, |
|
) |
|
|
|
log.info("Completed data augmentation") |
|
|
|
hint_key = data_batch["hint_key"] |
|
control_input = data_batch[hint_key] |
|
input_video = data_batch.get("input_video", None) |
|
control_weight = data_batch.get("control_weight", None) |
|
num_new_generated_frames = self.num_video_frames - self.num_input_frames |
|
B, C, T, H, W = control_input.shape |
|
if (T - self.num_input_frames) % num_new_generated_frames != 0: |
|
pad_t = num_new_generated_frames - ((T - self.num_input_frames) % num_new_generated_frames) |
|
pad_frames = control_input[:, :, -1:].repeat(1, 1, pad_t, 1, 1) |
|
control_input = torch.cat([control_input, pad_frames], dim=2) |
|
if input_video is not None: |
|
pad_video = input_video[:, :, -1:].repeat(1, 1, pad_t, 1, 1) |
|
input_video = torch.cat([input_video, pad_video], dim=2) |
|
num_total_frames_with_padding = control_input.shape[2] |
|
if ( |
|
isinstance(control_weight, torch.Tensor) |
|
and control_weight.ndim > 5 |
|
and num_total_frames_with_padding > control_weight.shape[3] |
|
): |
|
pad_t = num_total_frames_with_padding - control_weight.shape[3] |
|
pad_weight = control_weight[:, :, :, -1:].repeat(1, 1, 1, pad_t, 1, 1) |
|
control_weight = torch.cat([control_weight, pad_weight], dim=3) |
|
else: |
|
num_total_frames_with_padding = T |
|
N_clip = (num_total_frames_with_padding - self.num_input_frames) // num_new_generated_frames |
|
|
|
video = [] |
|
initial_condition_input = None |
|
|
|
prev_frames = None |
|
if input_video is not None: |
|
prev_frames = torch.zeros_like(input_video).cuda() |
|
prev_frames[:, :, : self.num_input_frames] = (input_video[:, :, : self.num_input_frames] + 1) * 255.0 / 2 |
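            # Seed the autoregressive loop with the first num_input_frames of the input video,
            # converted from [-1, 1] back to the [0, 255] range expected by the conditioning path.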
|
log.info(f"N_clip: {N_clip}") |
|
for i_clip in tqdm(range(N_clip)): |
|
log.info(f"input_video shape: {input_video.shape}") |
|
|
|
data_batch_i = {k: v for k, v in data_batch.items()} |
|
start_frame = num_new_generated_frames * i_clip |
|
end_frame = num_new_generated_frames * (i_clip + 1) + self.num_input_frames |
|
|
|
|
|
if input_video is not None: |
|
input_frames = input_video[:, :, start_frame:end_frame].cuda() |
|
x0 = self.model.encode(input_frames).contiguous() |
|
x_sigma_max = self.model.get_x_from_clean(x0, self.sigma_max, seed=(self.seed + i_clip)) |
|
            else:
                raise ValueError("input_video is required for the distilled control2world pipeline")
|
|
|
data_batch_i[hint_key] = control_input[:, :, start_frame:end_frame].cuda() |
|
latent_hint = [] |
|
log.info("Starting latent encoding") |
|
for b in range(B): |
|
data_batch_p = {k: v for k, v in data_batch_i.items()} |
|
data_batch_p[hint_key] = data_batch_i[hint_key][b : b + 1] |
|
if len(control_inputs_list) >= 1 and len(control_inputs_list[0]) > 1: |
|
latent_hint_i = [] |
|
for idx in range(0, data_batch_p[hint_key].size(1), 3): |
|
x_rgb = data_batch_p[hint_key][:, idx : idx + 3] |
|
latent_hint_i.append(self.model.encode(x_rgb)) |
|
latent_hint.append(torch.cat(latent_hint_i).unsqueeze(0)) |
|
else: |
|
latent_hint.append(self.model.encode_latent(data_batch_p)) |
|
data_batch_i["latent_hint"] = latent_hint = torch.cat(latent_hint) |
|
log.info("Completed latent encoding") |
|
|
|
|
|
if isinstance(control_weight, torch.Tensor) and control_weight.ndim > 4: |
|
control_weight_t = control_weight[..., start_frame:end_frame, :, :] |
|
t, h, w = latent_hint.shape[-3:] |
|
data_batch_i["control_weight"] = resize_control_weight_map(control_weight_t, (t, h // 2, w // 2)) |
|
|
|
num_input_frames = self.num_input_frames |
|
prev_frames_patched = split_video_into_patches( |
|
prev_frames, control_input.shape[-2], control_input.shape[-1] |
|
) |
|
input_frames = prev_frames_patched.bfloat16() / 255.0 * 2 - 1 |
|
condition_latent = self.model.encode(input_frames).contiguous() |
|
|
|
|
|
log.info("Starting diffusion sampling") |
|
latents = generate_world_from_control( |
|
model=self.model, |
|
state_shape=state_shape, |
|
is_negative_prompt=True, |
|
data_batch=data_batch_i, |
|
guidance=self.guidance, |
|
num_steps=self.num_steps, |
|
seed=(self.seed + i_clip), |
|
condition_latent=condition_latent, |
|
num_input_frames=num_input_frames, |
|
sigma_max=self.sigma_max if x_sigma_max is not None else None, |
|
x_sigma_max=x_sigma_max, |
|
) |
|
log.info("Completed diffusion sampling") |
|
|
|
log.info("Starting VAE decode") |
|
frames = self._run_tokenizer_decoding(latents) |
|
log.info("Completed VAE decode") |
|
|
|
if i_clip == 0: |
|
video.append(frames) |
|
else: |
|
video.append(frames[:, :, self.num_input_frames :]) |
|
|
|
prev_frames = torch.zeros_like(frames) |
|
prev_frames[:, :, : self.num_input_frames] = frames[:, :, -self.num_input_frames :] |
|
|
|
video = torch.cat(video, dim=2)[:, :, :T] |
|
video = video.permute(0, 2, 3, 4, 1).numpy() |
|
return video |
|
|