# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from diffusers import EDMEulerScheduler
from megatron.core import parallel_state
from torch import Tensor

from cosmos_predict1.diffusion.conditioner import BaseVideoCondition
from cosmos_predict1.diffusion.module import parallel
from cosmos_predict1.diffusion.module.blocks import FourierFeatures
from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp
from cosmos_predict1.diffusion.module.pretrained_vae import BaseVAE
from cosmos_predict1.diffusion.training.utils.layer_control.peft_control_config_parser import LayerControlConfigParser
from cosmos_predict1.diffusion.training.utils.peft.peft import add_lora_layers, setup_lora_requires_grad
from cosmos_predict1.utils import log, misc
from cosmos_predict1.utils.distributed import get_rank
from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate


class DiffusionT2WModel(torch.nn.Module):
    """Text-to-world diffusion model that generates video frames from text descriptions.

    This model implements a diffusion-based approach for generating videos conditioned on text input.
    It handles the full pipeline including encoding/decoding through a VAE, diffusion sampling,
    and classifier-free guidance.
    """

    def __init__(self, config):
        """Initialize the diffusion model.

        Args:
            config: Configuration object containing model parameters and architecture settings.
        """
        super().__init__()
        self.config = config
        self.precision = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }[config.precision]
        self.tensor_kwargs = {"device": "cuda", "dtype": self.precision}
        log.debug(f"DiffusionModel: precision {self.precision}")
        # 1. Set data keys and data information.
        self.sigma_data = config.sigma_data
        self.state_shape = list(config.latent_shape)
        self.setup_data_key()
        # 2. Set up the diffusion pre-conditioning (scaling) and the sampler.
        self.scheduler = EDMEulerScheduler(sigma_max=80, sigma_min=0.0002, sigma_data=self.sigma_data)
        self.tokenizer = None
        self.model = None

    @property
    def net(self):
        return self.model.net

    @property
    def conditioner(self):
        return self.model.conditioner

    @property
    def logvar(self):
        return self.model.logvar

    def set_up_tokenizer(self, tokenizer_dir: str):
        """Instantiate the tokenizer (VAE) from config and load its weights from tokenizer_dir."""
        self.tokenizer: BaseVAE = lazy_instantiate(self.config.tokenizer)
        self.tokenizer.load_weights(tokenizer_dir)
        if hasattr(self.tokenizer, "reset_dtype"):
            self.tokenizer.reset_dtype()

    def set_up_model(self, memory_format: torch.memory_format = torch.preserve_format):
        """Initialize the core model components including network, conditioner and logvar."""
        self.model = self.build_model()
        if self.config.peft_control and self.config.peft_control.enabled:
            log.info("Setting up LoRA layers")
            peft_control_config_parser = LayerControlConfigParser(config=self.config.peft_control)
            peft_control_config = peft_control_config_parser.parse()
            add_lora_layers(self.model, peft_control_config)
            num_lora_params = setup_lora_requires_grad(self.model)
            # Inference-only model: freeze every parameter, LoRA layers included.
            self.model.requires_grad_(False)
            if num_lora_params == 0:
                raise ValueError("No LoRA parameters found. Please check the model configuration.")
        self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs)

    def build_model(self) -> torch.nn.ModuleDict:
        """Construct the model's neural network components.

        Returns:
            ModuleDict containing the network, conditioner and logvar components.
        """
        config = self.config
        net = lazy_instantiate(config.net)
        conditioner = lazy_instantiate(config.conditioner)
        logvar = torch.nn.Sequential(
            FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False)
        )
        return torch.nn.ModuleDict(
            {
                "net": net,
                "conditioner": conditioner,
                "logvar": logvar,
            }
        )

    def encode(self, state: torch.Tensor) -> torch.Tensor:
        """Encode input state into latent representation using VAE.

        Args:
            state: Input tensor to encode

        Returns:
            Encoded latent representation scaled by sigma_data
        """
        # Scale latents by sigma_data to match the EDM pre-conditioning convention.
        return self.tokenizer.encode(state) * self.sigma_data

    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """Decode latent representation back to pixel space using VAE.

        Args:
            latent: Latent tensor to decode

        Returns:
            Decoded tensor in pixel space
        """
        # Undo the sigma_data scaling applied in encode() before decoding.
        return self.tokenizer.decode(latent / self.sigma_data)

    def setup_data_key(self) -> None:
        """Configure input data keys for video and image data."""
        self.input_data_key = self.config.input_data_key  # by default, the video key for the video diffusion model

    def generate_samples_from_batch(
        self,
        data_batch: dict,
        guidance: float = 1.5,
        seed: int = 1,
        state_shape: tuple | None = None,
        n_sample: int | None = 1,
        is_negative_prompt: bool = False,
        num_steps: int = 35,
    ) -> Tensor:
        """Generate samples from a data batch using diffusion sampling.

        This function generates samples from either image or video data batches using diffusion sampling.
        It handles both conditional and unconditional generation with classifier-free guidance.

        Args:
            data_batch (dict): Raw data batch from the training data loader.
            guidance (float, optional): Classifier-free guidance weight. Defaults to 1.5.
            seed (int, optional): Random seed for reproducibility. Defaults to 1.
            state_shape (tuple | None, optional): Shape of the state tensor. Uses self.state_shape if None. Defaults to None.
            n_sample (int | None, optional): Number of samples to generate. Defaults to 1.
            is_negative_prompt (bool, optional): Whether to use a negative prompt for the unconditional branch. Defaults to False.
            num_steps (int, optional): Number of diffusion sampling steps. Defaults to 35.

        Returns:
            Tensor: Generated samples after diffusion sampling
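
        Example (illustrative sketch, not a doctest; assumes set_up_model() and
        set_up_tokenizer() have already been called and that `data_batch` carries
        the keys the conditioner expects):
            latents = model.generate_samples_from_batch(data_batch, guidance=1.5, num_steps=35)
            video = model.decode(latents)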
| """ | |
        condition, uncondition = self._get_conditions(data_batch, is_negative_prompt)

        self.scheduler.set_timesteps(num_steps)
        if state_shape is None:
            state_shape = self.state_shape
        # Draw the initial noise from a seeded generator so sampling is reproducible.
        generator = torch.Generator().manual_seed(seed)
        xt = torch.randn(size=(n_sample,) + tuple(state_shape), generator=generator) * self.scheduler.init_noise_sigma

        to_cp = self.net.is_context_parallel_enabled
        if to_cp:
            # Shard the sequence (temporal) dimension across the context-parallel group.
            xt = split_inputs_cp(x=xt, seq_dim=2, cp_group=self.net.cp_group)

        for t in self.scheduler.timesteps:
            xt = xt.to(**self.tensor_kwargs)
            xt_scaled = self.scheduler.scale_model_input(xt, timestep=t)
            t = t.to(**self.tensor_kwargs)
            # Predict the noise residual for the conditional and unconditional branches.
            net_output_cond = self.net(x=xt_scaled, timesteps=t, **condition.to_dict())
            net_output_uncond = self.net(x=xt_scaled, timesteps=t, **uncondition.to_dict())
            # Classifier-free guidance, anchored at the conditional prediction.
            net_output = net_output_cond + guidance * (net_output_cond - net_output_uncond)
            # Compute the previous noisy sample x_t -> x_{t-1}.
            xt = self.scheduler.step(net_output, t, xt).prev_sample
        samples = xt

        if to_cp:
            # Gather the shards back into the full-length sequence.
            samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group)

        return samples

    def _get_conditions(
        self,
        data_batch: dict,
        is_negative_prompt: bool = False,
    ):
        """Get the conditions for the model.

        Args:
            data_batch: Input data dictionary
            is_negative_prompt: Whether to use negative prompting

        Returns:
            condition: Input conditions
            uncondition: Conditions removed/reduced to minimum (unconditioned)
        """
        if is_negative_prompt:
            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
        else:
            condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

        to_cp = self.net.is_context_parallel_enabled
        # For inference, check if parallel_state is initialized
        if parallel_state.is_initialized():
            condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp)
            uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp)

        return condition, uncondition


def broadcast_condition(condition: BaseVideoCondition, to_tp: bool = True, to_cp: bool = True) -> BaseVideoCondition:
    """Broadcast each field of the condition across tensor- and/or context-parallel groups."""
    condition_kwargs = {}
    for k, v in condition.to_dict().items():
        if isinstance(v, torch.Tensor):
            assert not v.requires_grad, f"{k} requires gradient. the current impl does not support it"
        condition_kwargs[k] = parallel.broadcast(v, to_tp=to_tp, to_cp=to_cp)
    condition = type(condition)(**condition_kwargs)
    return condition
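

# The helper below is an illustrative usage sketch, not part of the library API. It assumes
# `cfg` is a lazy config providing the fields this module reads (net, conditioner, tokenizer,
# precision, sigma_data, latent_shape, input_data_key, peft_control), that `data_batch`
# already carries the keys the conditioner expects, and that `tokenizer_dir` points at
# downloaded VAE weights.
def _example_generate(cfg, tokenizer_dir: str, data_batch: dict) -> torch.Tensor:
    """Hypothetical end-to-end sketch: build the model, load the tokenizer, sample, decode."""
    model = DiffusionT2WModel(cfg)
    model.set_up_model()  # builds net/conditioner/logvar and moves them to CUDA
    model.set_up_tokenizer(tokenizer_dir)  # loads the VAE used by encode()/decode()
    latents = model.generate_samples_from_batch(data_batch, guidance=1.5, seed=1, num_steps=35)
    return model.decode(latents)  # map latents back to pixel space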