from typing import Optional, Tuple

import torch

from model.base import BaseModel
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper


class CausalDiffusion(BaseModel):
    """Causal Wan diffusion model (``is_causal=True``) trained with a per-frame
    diffusion regression loss on clean latents, optionally with teacher forcing."""

    def __init__(self, args, device):
        """
        Initialize the causal diffusion training module.

        Args:
            args: configuration object. Attributes read here:
                ``num_frame_per_block`` (default 1), ``independent_first_frame``
                (default False), ``gradient_checkpointing`` (default False),
                ``num_train_timestep``, ``guidance_scale``, ``timestep_shift``
                (default 1.0), ``teacher_forcing`` (default False) and
                ``noise_augmentation_max_timestep`` (default 0).
            device: forwarded to ``BaseModel.__init__``.
        """
        # NOTE(review): self.generator is used right after this call, so
        # BaseModel.__init__ presumably invokes _initialize_models — confirm.
        super().__init__(args, device)

        # Step 1: configure the generator's causal structure.
        self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
        if self.num_frame_per_block > 1:
            self.generator.model.num_frame_per_block = self.num_frame_per_block

        self.independent_first_frame = getattr(args, "independent_first_frame", False)
        if self.independent_first_frame:
            self.generator.model.independent_first_frame = True

        # Optional like the other flags above; configs that set it behave as before.
        if getattr(args, "gradient_checkpointing", False):
            self.generator.enable_gradient_checkpointing()

        # Step 2: initialize all hyperparameters.
        self.num_train_timestep = args.num_train_timestep
        # 2% / 98% bounds of the training timestep range. Not read by
        # generator_loss below; presumably used elsewhere — verify before removing.
        self.min_step = int(0.02 * self.num_train_timestep)
        self.max_step = int(0.98 * self.num_train_timestep)
        self.guidance_scale = args.guidance_scale
        self.timestep_shift = getattr(args, "timestep_shift", 1.0)
        self.teacher_forcing = getattr(args, "teacher_forcing", False)
        # Noise augmentation for teacher forcing: small noise is added to the
        # clean context latents; 0 disables it.
        self.noise_augmentation_max_timestep = getattr(args, "noise_augmentation_max_timestep", 0)

    def _initialize_models(self, args):
        """Build the trainable causal generator and the frozen text encoder and VAE.

        ``args.model_kwargs`` (default ``{}``) is forwarded to ``WanDiffusionWrapper``.
        """
        self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
        self.generator.model.requires_grad_(True)

        self.text_encoder = WanTextEncoder()
        self.text_encoder.requires_grad_(False)

        self.vae = WanVAEWrapper()
        self.vae.requires_grad_(False)

    def _sample_frame_timesteps(self, max_index, batch_size, num_frame):
        """Sample per-frame timesteps for a [batch_size, num_frame] video.

        Indices between 0 and ``max_index`` are drawn with ``_get_timestep``
        (non-uniform across frames, grouped by ``num_frame_per_block``) and mapped
        through ``scheduler.timesteps``; the result is cast to the model's
        dtype/device.
        """
        index = self._get_timestep(
            0,
            max_index,
            batch_size,
            num_frame,
            self.num_frame_per_block,
            uniform_timestep=False,
        )
        return self.scheduler.timesteps[index].to(dtype=self.dtype, device=self.device)

    def _add_noise_per_frame(self, latent, noise, timestep):
        """Noise ``latent`` frame-wise with ``scheduler.add_noise``.

        The leading (batch, frame) dims of ``latent``, ``noise`` and ``timestep``
        are flattened into one axis for the scheduler, then restored from
        ``latent``'s own leading dims, so the output has ``latent``'s layout.
        """
        batch_size, num_frame = latent.shape[:2]
        noisy = self.scheduler.add_noise(
            latent.flatten(0, 1),
            noise.flatten(0, 1),
            timestep.flatten(0, 1),
        )
        return noisy.unflatten(0, (batch_size, num_frame))

    def generator_loss(
        self,
        image_or_video_shape,
        conditional_dict: dict,
        unconditional_dict: dict,
        clean_latent: torch.Tensor,
        initial_latent: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
        """
        Compute the causal diffusion (denoising regression) loss on real latents.

        Each frame of ``clean_latent`` is noised at its own sampled timestep (see
        ``_get_timestep``) and the generator is trained to regress
        ``scheduler.training_target`` for that timestep. When ``teacher_forcing``
        is enabled, the clean latents, optionally with slight noise augmentation,
        are also fed to the generator as ``clean_x`` context.

        Input:
            - image_or_video_shape: shape of the video [B, F, C, H, W]; only B and
              F are read.
            - conditional_dict: conditioning forwarded to the generator (e.g. text
              embeddings).
            - unconditional_dict: not used by this loss; kept for interface
              compatibility.
            - clean_latent: clean latents [B, F, C, H, W] that are noised and
              regressed against.
            - initial_latent: not used by this loss; kept for interface
              compatibility.
        Output:
            - loss: scalar tensor, the per-frame MSE between the generator's flow
              prediction and the training target, weighted by
              ``scheduler.training_weight`` and averaged over batch and frames.
            - log_dict: {"x0": clean latents, "x0_pred": generator's x0
              prediction}, both detached.
        """
        noise = torch.randn_like(clean_latent)
        batch_size, num_frame = image_or_video_shape[:2]

        # Step 1: sample per-frame timesteps and noise the denoiser inputs.
        timestep = self._sample_frame_timesteps(
            self.scheduler.num_train_timesteps, batch_size, num_frame
        )
        noisy_latents = self._add_noise_per_frame(clean_latent, noise, timestep)
        training_target = self.scheduler.training_target(clean_latent, noise, timestep)

        # Step 2: noise augmentation — add small noise to the clean context latents.
        # NOTE(review): this samples *indices* 0..noise_augmentation_max_timestep
        # into scheduler.timesteps; that yields small noise only if low indices
        # map to low noise levels — TODO confirm the scheduler's ordering.
        # NOTE(review): the same `noise` tensor as the denoiser input is reused,
        # and the augmentation runs even when teacher_forcing is off (its result
        # is then discarded). Both are kept as-is to preserve training behavior.
        if self.noise_augmentation_max_timestep > 0:
            timestep_clean_aug = self._sample_frame_timesteps(
                self.noise_augmentation_max_timestep, batch_size, num_frame
            )
            clean_latent_aug = self._add_noise_per_frame(clean_latent, noise, timestep_clean_aug)
        else:
            clean_latent_aug = clean_latent
            timestep_clean_aug = None

        # Step 3: run the generator; clean context is supplied only under teacher forcing.
        flow_pred, x0_pred = self.generator(
            noisy_image_or_video=noisy_latents,
            conditional_dict=conditional_dict,
            timestep=timestep,
            clean_x=clean_latent_aug if self.teacher_forcing else None,
            aug_t=timestep_clean_aug if self.teacher_forcing else None,
        )

        # Step 4: per-frame MSE (mean over C, H, W -> [B, F]), weighted per frame.
        per_frame_loss = torch.nn.functional.mse_loss(
            flow_pred.float(), training_target.float(), reduction="none"
        ).mean(dim=(2, 3, 4))
        # training_weight's leading dim is B*F; reshape it to [B, F] to match.
        frame_weight = self.scheduler.training_weight(timestep).unflatten(0, (batch_size, num_frame))
        loss = (per_frame_loss * frame_weight).mean()

        log_dict = {
            "x0": clean_latent.detach(),
            "x0_pred": x0_pred.detach(),
        }
        return loss, log_dict