|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union |
|
|
|
import numpy as np |
|
import torch |
|
from megatron.core import parallel_state |
|
from torch import Tensor |
|
|
|
from cosmos_transfer1.diffusion.conditioner import VideoExtendCondition |
|
from cosmos_transfer1.diffusion.config.base.conditioner import VideoCondBoolConfig |
|
from cosmos_transfer1.diffusion.diffusion.functional.batch_ops import batch_mul |
|
from cosmos_transfer1.diffusion.model.model_t2w import DataType, DiffusionT2WModel, DistillT2WModel |
|
from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp |
|
from cosmos_transfer1.utils import log, misc |
|
|
|
|
|
@dataclass
class VideoDenoisePrediction:
    """Container for the outputs of one video denoising step."""

    # Denoised prediction (condition regions may be replaced by ground truth).
    x0: torch.Tensor
    # Noise prediction, computed as (noise_x - x0) / sigma.
    eps: Optional[torch.Tensor] = None
    # Log-variance of the prediction, forwarded from the base model's output.
    logvar: Optional[torch.Tensor] = None
    # Network input before c_in multiplication (condition region mixed in).
    xt: Optional[torch.Tensor] = None
    # x0 prediction with condition regions replaced by the ground-truth latent.
    x0_pred_replaced: Optional[torch.Tensor] = None
|
|
|
|
|
class DiffusionV2WModel(DiffusionT2WModel):
    """Video-to-world diffusion model that extends the text-to-world model with
    conditioning on ground-truth video frames (video extension inference)."""

    def __init__(self, config):
        super().__init__(config)

    def augment_conditional_latent_frames(
        self,
        condition: VideoExtendCondition,
        cfg_video_cond_bool: VideoCondBoolConfig,
        gt_latent: Tensor,
        condition_video_augment_sigma_in_inference: float = 0.001,
        sigma: Optional[Tensor] = None,
        seed: int = 1,
    ) -> Tuple[VideoExtendCondition, Tensor]:
        """Augments the conditional frames with noise during inference.

        Args:
            condition (VideoExtendCondition): condition object
                condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor.
                condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network.
            cfg_video_cond_bool (VideoCondBoolConfig): video condition bool config
            gt_latent (Tensor): ground truth latent tensor in shape B,C,T,H,W
            condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
            sigma (Tensor): noise level for the generation region
            seed (int): random seed for reproducibility
        Returns:
            VideoExtendCondition: updated condition object
            augment_latent (Tensor): augmented (and c_in-rescaled) latent tensor in shape B,C,T,H,W
        """
        assert (
            condition_video_augment_sigma_in_inference is not None
        ), "condition_video_augment_sigma_in_inference should be provided"
        augment_sigma = condition_video_augment_sigma_in_inference

        if augment_sigma >= sigma.flatten()[0]:
            # The condition frames would be at least as noisy as the generation
            # region, so conditioning on them no longer helps: zero out the
            # condition indicator to disable the condition region entirely.
            log.debug("augment_sigma larger than sigma or other frame, remove condition")
            condition.condition_video_indicator = condition.condition_video_indicator * 0

        B = gt_latent.shape[0]
        augment_sigma = torch.full((B,), augment_sigma, **self.tensor_kwargs)

        # Architecture-invariant RNG so results are reproducible across setups.
        noise = misc.arch_invariant_rand(
            gt_latent.shape,
            torch.float32,
            self.tensor_kwargs["device"],
            seed,
        )

        # Perturb the ground-truth latent at the (small) augmentation sigma.
        augment_latent = gt_latent + noise * augment_sigma[:, None, None, None, None]

        # Apply c_in scaling appropriate for the augmentation sigma ...
        _, _, c_in_augment, _ = self.scaling(sigma=augment_sigma)
        augment_latent_cin = batch_mul(augment_latent, c_in_augment)

        # ... then divide out the c_in that the network will apply at the
        # generation sigma, so that after the network's own scaling the
        # condition region is effectively scaled by c_in(augment_sigma).
        _, _, c_in, _ = self.scaling(sigma=sigma)
        augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in)

        return condition, augment_latent_cin

    def denoise(
        self,
        noise_x: Tensor,
        sigma: Tensor,
        condition: VideoExtendCondition,
        condition_video_augment_sigma_in_inference: float = 0.001,
        seed: int = 1,
    ) -> VideoDenoisePrediction:
        """Denoises input tensor using conditional video generation.

        Args:
            noise_x (Tensor): Noisy input tensor.
            sigma (Tensor): Noise level.
            condition (VideoExtendCondition): Condition for denoising.
            condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
            seed (int): Random seed for reproducibility
        Returns:
            VideoDenoisePrediction containing:
                - x0: Denoised prediction
                - eps: Noise prediction
                - logvar: Log variance of noise prediction
                - xt: Input before c_in multiplication
                - x0_pred_replaced: x0 prediction with condition regions replaced by ground truth
        """
        assert (
            condition.gt_latent is not None
        ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}"
        gt_latent = condition.gt_latent
        cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool

        condition_latent = gt_latent

        # Lightly noise the condition frames and pre-compensate their c_in scaling.
        condition, augment_latent = self.augment_conditional_latent_frames(
            condition, cfg_video_cond_bool, condition_latent, condition_video_augment_sigma_in_inference, sigma, seed
        )
        condition_video_indicator = condition.condition_video_indicator

        # Under context parallelism, shard all temporal tensors along dim 2 so
        # each rank works on its own slice of the time axis.
        if parallel_state.get_context_parallel_world_size() > 1:
            cp_group = parallel_state.get_context_parallel_group()
            condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group)
            augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group)
            gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group)

        # Compose the network input: condition region takes the augmented
        # latent, generation region keeps the noisy input.
        new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x

        denoise_pred = super().denoise(new_noise_xt, sigma, condition)

        # Replace the condition region of the prediction with ground truth.
        x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0

        x0_pred = x0_pred_replaced

        return VideoDenoisePrediction(
            x0=x0_pred,
            eps=batch_mul(noise_x - x0_pred, 1.0 / sigma),
            logvar=denoise_pred.logvar,
            xt=new_noise_xt,
            x0_pred_replaced=x0_pred_replaced,
        )

    # Chunking strategy for get_chunks_indices: "shuffle", "in_order", or "rand_order".
    CHUNKING_MODE = "rand_order"
    # When True, the first chunk gets a random length in [1, chunking_size]
    # so chunk boundaries vary between calls.
    IS_STAGGERED = True

    def get_chunks_indices(self, total_flen: int, chunking_size: int) -> List[torch.Tensor]:
        """Partition frame indices [0, total_flen) into chunks of size `chunking_size`.

        Returns a list of index tensors whose order depends on CHUNKING_MODE:
        "shuffle" shuffles the chunks; "in_order" keeps them as-is;
        "rand_order" reverses them with probability 0.5.
        """
        chunks_indices = []
        if self.CHUNKING_MODE == "shuffle":
            for index in torch.arange(0, total_flen, 1).split(chunking_size):
                chunks_indices.append(index)
            # In-place shuffle of the chunk order (each chunk stays contiguous).
            np.random.shuffle(chunks_indices)
        else:
            # Staggering: randomize where the first chunk ends so the chunk
            # boundaries are not always at multiples of chunking_size.
            first_chunk_end = (
                int(torch.randint(low=0, high=chunking_size, size=(1,)) + 1) if self.IS_STAGGERED else chunking_size
            )

            if first_chunk_end >= total_flen:
                # Everything fits in a single chunk.
                chunks_indices.append(torch.arange(total_flen))
            else:
                chunks_indices.append(torch.arange(first_chunk_end))

                for index in torch.arange(first_chunk_end, total_flen, 1).split(chunking_size):
                    chunks_indices.append(index)

            if self.CHUNKING_MODE == "in_order":
                pass
            elif self.CHUNKING_MODE == "rand_order":
                # Randomly reverse the processing order of the chunks.
                if np.random.rand() > 0.5:
                    chunks_indices = chunks_indices[::-1]
            else:
                raise NotImplementedError(f"{self.CHUNKING_MODE} mode not implemented!!")

        return chunks_indices

    def generate_samples_from_batch(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        seed: int = 1,
        state_shape: Tuple | None = None,
        n_sample: int | None = None,
        is_negative_prompt: bool = False,
        num_steps: int = 35,
        condition_latent: Union[torch.Tensor, None] = None,
        num_condition_t: Union[int, None] = None,
        condition_video_augment_sigma_in_inference: Optional[float] = None,
        add_input_frames_guidance: bool = False,
        x_sigma_max: Optional[torch.Tensor] = None,
        sigma_max: Optional[float] = None,
        chunking: Optional[int] = None,
        **kwargs,
    ) -> Tensor:
        """Generates video samples conditioned on input frames.

        Args:
            data_batch: Input data dictionary
            guidance: Classifier-free guidance scale
            seed: Random seed for reproducibility
            state_shape: Shape of output tensor (defaults to model's state shape)
            n_sample: Number of samples to generate (defaults to batch size)
            is_negative_prompt: Whether to use negative prompting
            num_steps: Number of denoising steps
            condition_latent: Conditioning frames tensor (B,C,T,H,W)
            num_condition_t: Number of frames to condition on
            condition_video_augment_sigma_in_inference: Noise level for condition augmentation
            add_input_frames_guidance: Whether to apply guidance to input frames
            x_sigma_max: Maximum noise level tensor
            sigma_max: Maximum noise level (defaults to the SDE's sigma_max)
            chunking: Chunking size, if None, chunking is disabled

        Returns:
            Generated video samples tensor
        """
        if n_sample is None:
            input_key = self.input_data_key
            n_sample = data_batch[input_key].shape[0]
        if state_shape is None:
            log.debug(f"Default Video state shape is used. {self.state_shape}")
            state_shape = self.state_shape

        assert condition_latent is not None, "condition_latent should be provided"

        log.info("x0_fn")
        x0_fn = self.get_x0_fn_from_batch_with_condition_latent(
            data_batch,
            guidance,
            is_negative_prompt=is_negative_prompt,
            condition_latent=condition_latent,
            num_condition_t=num_condition_t,
            condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
            add_input_frames_guidance=add_input_frames_guidance,
            seed=seed,
            chunking=chunking,
        )
        if sigma_max is None:
            sigma_max = self.sde.sigma_max
        if x_sigma_max is None:
            # Start from pure noise at sigma_max (reproducible via seed).
            x_sigma_max = (
                misc.arch_invariant_rand(
                    (n_sample,) + tuple(state_shape),
                    torch.float32,
                    self.tensor_kwargs["device"],
                    seed,
                )
                * sigma_max
            )

        # Shard along the time axis for context parallelism, sample, then
        # gather the shards back into a full-length video.
        if self.net.is_context_parallel_enabled:
            x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group)

        samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max)

        if self.net.is_context_parallel_enabled:
            samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group)

        return samples

    def get_x0_fn_from_batch_with_condition_latent(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        is_negative_prompt: bool = False,
        condition_latent: Optional[torch.Tensor] = None,
        num_condition_t: Union[int, None] = None,
        condition_video_augment_sigma_in_inference: Optional[float] = None,
        add_input_frames_guidance: bool = False,
        seed: int = 1,
        chunking: Optional[int] = None,
    ) -> Callable:
        """Creates denoising function for conditional video generation.

        Args:
            data_batch: Input data dictionary
            guidance: Classifier-free guidance scale
            is_negative_prompt: Whether to use negative prompting
            condition_latent: Conditioning frames tensor (B,C,T,H,W)
            num_condition_t: Number of frames to condition on
            condition_video_augment_sigma_in_inference: Noise level for condition augmentation
            add_input_frames_guidance: Whether to apply guidance to input frames
            seed: Random seed for reproducibility
            chunking: Chunking size, if None, chunking is disabled

        Returns:
            Function that takes noisy input and noise level and returns denoised prediction
        """
        if chunking is None:
            log.info("no chunking")

            if is_negative_prompt:
                condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
            else:
                condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

            condition.video_cond_bool = True
            condition = self.add_condition_video_indicator_and_video_input_mask(
                condition_latent, condition, num_condition_t
            )

            # When add_input_frames_guidance is set, the unconditional branch
            # also drops the input-frame conditioning (CFG over input frames).
            uncondition.video_cond_bool = False if add_input_frames_guidance else True
            uncondition = self.add_condition_video_indicator_and_video_input_mask(
                condition_latent, uncondition, num_condition_t
            )

            def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
                cond_x0 = self.denoise(
                    noise_x,
                    sigma,
                    condition,
                    condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
                    seed=seed,
                ).x0_pred_replaced
                uncond_x0 = self.denoise(
                    noise_x,
                    sigma,
                    uncondition,
                    condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
                    seed=seed,
                ).x0_pred_replaced

                # Classifier-free guidance combination.
                return cond_x0 + guidance * (cond_x0 - uncond_x0)

            return x0_fn
        else:
            log.info("chunking !!!")

            def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
                # Conditions are re-fetched on every call so each chunk gets a
                # freshly initialized condition object before masking.
                if is_negative_prompt:
                    condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
                else:
                    condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

                # Output buffer assembled chunk by chunk (despite the name,
                # this holds guided x0 predictions, not noise).
                noises = torch.zeros_like(condition_latent)
                T = condition_latent.shape[2]
                for chunk_idx in self.get_chunks_indices(T, chunking):
                    latents_ = condition_latent[:, :, chunk_idx, :, :]
                    log.info(f"chunk_idx: {chunk_idx}, chunk shape: {latents_.shape}")

                    # NOTE(review): self.denoise below receives the full-length
                    # noise_x while the condition masks are built from this
                    # chunk's latent — confirm the temporal shapes broadcast as
                    # intended before relying on this branch.
                    condition.video_cond_bool = True
                    condition = self.add_condition_video_indicator_and_video_input_mask(
                        latents_, condition, num_condition_t
                    )

                    uncondition.video_cond_bool = False if add_input_frames_guidance else True
                    uncondition = self.add_condition_video_indicator_and_video_input_mask(
                        latents_, uncondition, num_condition_t
                    )

                    cond_x0 = self.denoise(
                        noise_x,
                        sigma,
                        condition,
                        condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
                        seed=seed,
                    ).x0_pred_replaced
                    uncond_x0 = self.denoise(
                        noise_x,
                        sigma,
                        uncondition,
                        condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
                        seed=seed,
                    ).x0_pred_replaced

                    # Classifier-free guidance, written back into this chunk's slot.
                    noises[:, :, chunk_idx, :, :] = cond_x0 + guidance * (cond_x0 - uncond_x0)

                return noises
            return x0_fn

    def add_condition_video_indicator_and_video_input_mask(
        self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None
    ) -> VideoExtendCondition:
        """Adds conditioning masks to VideoExtendCondition object.

        Creates binary indicators and input masks for conditional video generation.

        Args:
            latent_state: Input latent tensor (B,C,T,H,W)
            condition: VideoExtendCondition object to update
            num_condition_t: Number of frames to condition on

        Returns:
            Updated VideoExtendCondition with added masks:
                - condition_video_indicator: Binary tensor marking condition regions
                - condition_video_input_mask: Input mask for network
                - gt_latent: Ground truth latent tensor
        """
        T = latent_state.shape[2]
        latent_dtype = latent_state.dtype
        # Indicator broadcasts over batch, channel and spatial dims: 1x1xTx1x1.
        condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type(
            latent_dtype
        )

        assert num_condition_t is not None, "num_condition_t should be provided"
        assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}"
        log.debug(
            f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}"
        )
        # "first_n" conditioning: the first num_condition_t frames are condition frames.
        condition_video_indicator[:, :, :num_condition_t] += 1.0

        condition.gt_latent = latent_state
        condition.condition_video_indicator = condition_video_indicator

        B, C, T, H, W = latent_state.shape

        ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device)
        zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device)
        assert condition.video_cond_bool is not None, "video_cond_bool should be set"

        # Input mask is 1 in the condition region, 0 elsewhere; all zeros when
        # the video condition is disabled (unconditional branch).
        if condition.video_cond_bool:
            condition.condition_video_input_mask = (
                condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding
            )
        else:
            condition.condition_video_input_mask = zeros_padding

        return condition
|
|
|
|
|
class DistillV2WModel(DistillT2WModel):
    """ControlNet Video2World Distillation Model."""

    def augment_conditional_latent_frames(
        self,
        condition: VideoExtendCondition,
        cfg_video_cond_bool: VideoCondBoolConfig,
        gt_latent: Tensor,
        condition_video_augment_sigma_in_inference: float = 0.001,
        sigma: Optional[Tensor] = None,
        seed: int = 1,
    ) -> Tuple[VideoExtendCondition, Tensor]:
        """Augments the conditional frames with noise during inference.

        Args:
            condition (VideoExtendCondition): condition object
                condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor.
                condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network.
            cfg_video_cond_bool (VideoCondBoolConfig): video condition bool config
            gt_latent (Tensor): ground truth latent tensor in shape B,C,T,H,W
            condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
            sigma (Tensor): noise level for the generation region
            seed (int): random seed for reproducibility
        Returns:
            VideoExtendCondition: updated condition object (may carry
                condition_video_augment_sigma for the network)
            augment_latent (Tensor): augmented (and c_in-rescaled) latent tensor in shape B,C,T,H,W
        """
        assert (
            condition_video_augment_sigma_in_inference is not None
        ), "condition_video_augment_sigma_in_inference should be provided"
        augment_sigma = condition_video_augment_sigma_in_inference

        if augment_sigma >= sigma.flatten()[0]:
            # The condition frames would be at least as noisy as the generation
            # region, so conditioning no longer helps: disable the condition mask.
            log.debug("augment_sigma larger than sigma or other frame, remove condition")
            condition.condition_video_indicator = condition.condition_video_indicator * 0

        B = gt_latent.shape[0]
        augment_sigma = torch.full((B,), augment_sigma, **self.tensor_kwargs)

        # Architecture-invariant RNG so results are reproducible across setups.
        noise = misc.arch_invariant_rand(
            gt_latent.shape,
            torch.float32,
            self.tensor_kwargs["device"],
            seed,
        )

        # Perturb the ground-truth latent at the (small) augmentation sigma.
        augment_latent = gt_latent + noise * augment_sigma.view(B, 1, 1, 1, 1)
        _, _, c_in_augment, c_noise_augment = self.scaling(sigma=augment_sigma)

        # Optionally feed the augmentation noise level to the network; zeroed
        # out when there is no active condition region.
        if cfg_video_cond_bool.condition_on_augment_sigma:
            if condition.condition_video_indicator.sum() > 0:
                condition.condition_video_augment_sigma = c_noise_augment
            else:
                condition.condition_video_augment_sigma = torch.zeros_like(c_noise_augment)

        # Apply c_in at the augmentation sigma, then divide out the c_in that
        # will be applied at the generation sigma, so the condition region ends
        # up scaled by c_in(augment_sigma) after the network's own scaling.
        augment_latent_cin = batch_mul(augment_latent, c_in_augment)

        _, _, c_in, _ = self.scaling(sigma=sigma)
        augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in)

        return condition, augment_latent_cin

    def drop_out_condition_region(
        self, augment_latent: Tensor, noise_x: Tensor, cfg_video_cond_bool: VideoCondBoolConfig
    ) -> Tensor:
        """Use for CFG on input frames, we drop out the conditional region

        There are two option:
        1. when we dropout, we set the region to be zero
        2. when we dropout, we set the region to be noise_x

        Raises:
            NotImplementedError: for unrecognized cfg_unconditional_type values.
        """
        if cfg_video_cond_bool.cfg_unconditional_type == "zero_condition_region_condition_mask":
            # Option 1: condition region becomes all zeros.
            augment_latent_drop = torch.zeros_like(augment_latent)
        elif cfg_video_cond_bool.cfg_unconditional_type == "noise_x_condition_region":
            # Option 2: condition region becomes the noisy input itself.
            augment_latent_drop = noise_x
        else:
            raise NotImplementedError(
                f"cfg_unconditional_type {cfg_video_cond_bool.cfg_unconditional_type} not implemented"
            )
        return augment_latent_drop

    def denoise(
        self,
        noise_x: Tensor,
        sigma: Tensor,
        condition: VideoExtendCondition,
        condition_video_augment_sigma_in_inference: float = 0.001,
        seed: int = 1,
    ) -> VideoDenoisePrediction:
        """Denoises input tensor using conditional video generation.

        Args:
            noise_x (Tensor): Noisy input tensor.
            sigma (Tensor): Noise level.
            condition (VideoExtendCondition): Condition for denoising.
            condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
            seed (int): Random seed for reproducibility
        Returns:
            VideoDenoisePrediction containing:
                - x0: Denoised prediction
                - eps: Noise prediction
                - logvar: Log variance of noise prediction
                - xt: Input before c_in multiplication
                - x0_pred_replaced: x0 prediction with condition regions replaced by ground truth
        """
        # Validate gt_latent BEFORE the NaN scan: torch.isnan(None) would raise
        # an uninformative TypeError if gt_latent were missing.
        assert (
            condition.gt_latent is not None
        ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}"

        # Sanity check for NaNs in the inputs (logged, not fatal); uses the
        # module logger for consistency with the rest of the file.
        inputs_to_check = [noise_x, sigma, condition.gt_latent]
        for i, tensor in enumerate(inputs_to_check):
            if torch.isnan(tensor).any():
                log.warning(f"NaN found in input {i}")

        gt_latent = condition.gt_latent
        cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool

        condition_latent = gt_latent

        # Lightly noise the condition frames and pre-compensate their c_in scaling.
        condition, augment_latent = self.augment_conditional_latent_frames(
            condition, cfg_video_cond_bool, condition_latent, condition_video_augment_sigma_in_inference, sigma, seed
        )
        condition_video_indicator = condition.condition_video_indicator

        # Under context parallelism, shard all temporal tensors along dim 2.
        if parallel_state.get_context_parallel_world_size() > 1:
            cp_group = parallel_state.get_context_parallel_group()
            condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group)
            augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group)
            gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group)

        if not condition.video_cond_bool:
            # Unconditional branch of CFG over input frames: drop out the
            # condition region. (Fixed: original referenced undefined name
            # `xt`; the noisy input in scope here is `noise_x`.)
            augment_latent = self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool)

        # Compose the network input: condition region takes the augmented
        # latent, generation region keeps the noisy input.
        new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x

        denoise_pred = super().denoise(new_noise_xt, sigma, condition)

        # Replace the condition region of the prediction with ground truth.
        x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0

        x0_pred = x0_pred_replaced

        return VideoDenoisePrediction(
            x0=x0_pred,
            eps=batch_mul(noise_x - x0_pred, 1.0 / sigma),
            logvar=denoise_pred.logvar,
            xt=new_noise_xt,
            x0_pred_replaced=x0_pred_replaced,
        )

    def add_condition_video_indicator_and_video_input_mask(
        self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None
    ) -> VideoExtendCondition:
        """Adds conditioning masks to VideoExtendCondition object.

        Creates binary indicators and input masks for conditional video generation.

        Args:
            latent_state: Input latent tensor (B,C,T,H,W)
            condition: VideoExtendCondition object to update
            num_condition_t: Number of frames to condition on

        Returns:
            Updated VideoExtendCondition with added masks:
                - condition_video_indicator: Binary tensor marking condition regions
                - condition_video_input_mask: Input mask for network
                - gt_latent: Ground truth latent tensor
        """
        T = latent_state.shape[2]
        latent_dtype = latent_state.dtype
        # Indicator broadcasts over batch, channel and spatial dims: 1x1xTx1x1.
        condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type(
            latent_dtype
        )

        assert num_condition_t is not None, "num_condition_t should be provided"
        assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}"
        log.debug(
            f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}"
        )
        # "first_n" conditioning: the first num_condition_t frames are condition frames.
        condition_video_indicator[:, :, :num_condition_t] += 1.0

        condition.gt_latent = latent_state
        condition.condition_video_indicator = condition_video_indicator

        B, C, T, H, W = latent_state.shape

        ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device)
        zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device)
        assert condition.video_cond_bool is not None, "video_cond_bool should be set"

        # Input mask is 1 in the condition region, 0 elsewhere; all zeros when
        # the video condition is disabled (unconditional branch).
        if condition.video_cond_bool:
            condition.condition_video_input_mask = (
                condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding
            )
        else:
            condition.condition_video_input_mask = zeros_padding

        return condition
|
|