from __future__ import annotations import inspect import math import time import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union from dataclasses import dataclass from einops import rearrange, repeat import PIL.Image import numpy as np import torch from torch import nn import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers.pipelines.controlnet.pipeline_controlnet import ( StableDiffusionSafetyChecker, EXAMPLE_DOC_STRING, ) from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import ( StableDiffusionControlNetImg2ImgPipeline as DiffusersStableDiffusionControlNetImg2ImgPipeline, ) from diffusers.configuration_utils import FrozenDict from diffusers.models import AutoencoderKL, ControlNetModel from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from diffusers.pipelines.stable_diffusion.safety_checker import ( StableDiffusionSafetyChecker, ) from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import ( deprecate, logging, BaseOutput, replace_example_docstring, ) from diffusers.utils.torch_utils import is_compiled_module from diffusers.loaders import TextualInversionLoaderMixin from diffusers.models.attention import ( BasicTransformerBlock as DiffusersBasicTransformerBlock, ) from mmcm.vision.process.correct_color import ( hist_match_color_video_batch, hist_match_video_bcthw, ) from ..models.attention import BasicTransformerBlock from ..models.unet_3d_condition import UNet3DConditionModel from ..utils.noise_util import random_noise, video_fusion_noise from ..data.data_util import ( adaptive_instance_normalization, align_repeat_tensor_single_dim, batch_adain_conditioned_tensor, batch_concat_two_tensor_with_index, batch_index_select, fuse_part_tensor, ) from ..utils.text_emb_util import encode_weighted_prompt from ..utils.tensor_util import his_match from ..utils.timesteps_util import generate_parameters_with_timesteps from .context import get_context_scheduler, prepare_global_context logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class VideoPipelineOutput(BaseOutput): videos: Union[torch.Tensor, np.ndarray] latents: Union[torch.Tensor, np.ndarray] videos_mid: Union[torch.Tensor, np.ndarray] down_block_res_samples: Tuple[torch.FloatTensor] = None mid_block_res_samples: torch.FloatTensor = None up_block_res_samples: torch.FloatTensor = None mid_video_latents: List[torch.FloatTensor] = None mid_video_noises: List[torch.FloatTensor] = None def torch_dfs(model: torch.nn.Module): result = [model] for child in model.children(): result += torch_dfs(child) return result def prepare_image( image, # b c t h w batch_size, device, dtype, image_processor: Callable, num_images_per_prompt: int = 1, width=None, height=None, ): if isinstance(image, List) and isinstance(image[0], str): raise NotImplementedError if isinstance(image, List) and isinstance(image[0], np.ndarray): image = np.concatenate(image, axis=0) if isinstance(image, np.ndarray): image = torch.from_numpy(image) if image.ndim == 5: image = rearrange(image, "b c t h w-> (b t) c h w") if height is None: height = image.shape[-2] if width is None: width = image.shape[-1] width, height = (x - x % image_processor.vae_scale_factor for x in (width, height)) if height != image.shape[-2] or width != image.shape[-1]: image = torch.nn.functional.interpolate( image, size=(height, width), mode="bilinear" ) image = image.to(dtype=torch.float32) / 255.0 do_normalize = image_processor.config.do_normalize if image.min() < 0: warnings.warn( "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", FutureWarning, ) do_normalize = False if do_normalize: image = image_processor.normalize(image) image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size else: # image batch size is the same as prompt batch size repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) return image class MusevControlNetPipeline( DiffusersStableDiffusionControlNetImg2ImgPipeline, TextualInversionLoaderMixin ): """ a union diffusers pipeline, support 1. text2image model only, or text2video model, by setting skip_temporal_layer 2. text2video, image2video, video2video; 3. multi controlnet 4. IPAdapter 5. referencenet 6. IPAdapterFaceID """ _optional_components = [ "safety_checker", "feature_extractor", ] print_idx = 0 def __init__( self, vae: AutoencoderKL, unet: UNet3DConditionModel, scheduler: KarrasDiffusionSchedulers, controlnet: ControlNetModel | List[ControlNetModel] | Tuple[ControlNetModel] | MultiControlNetModel, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, # | MultiControlNetModel = None, # text_encoder: CLIPTextModel = None, # tokenizer: CLIPTokenizer = None, # safety_checker: StableDiffusionSafetyChecker = None, # feature_extractor: CLIPImageProcessor = None, requires_safety_checker: bool = False, referencenet: nn.Module = None, vision_clip_extractor: nn.Module = None, ip_adapter_image_proj: nn.Module = None, face_emb_extractor: nn.Module = None, facein_image_proj: nn.Module = None, ip_adapter_face_emb_extractor: nn.Module = None, ip_adapter_face_image_proj: nn.Module = None, pose_guider: nn.Module = None, ): super().__init__( vae, text_encoder, tokenizer, unet, controlnet, scheduler, safety_checker, feature_extractor, requires_safety_checker, ) self.referencenet = referencenet # ip_adapter if isinstance(vision_clip_extractor, nn.Module): vision_clip_extractor.to(dtype=self.unet.dtype, device=self.unet.device) self.vision_clip_extractor = vision_clip_extractor if isinstance(ip_adapter_image_proj, nn.Module): ip_adapter_image_proj.to(dtype=self.unet.dtype, device=self.unet.device) self.ip_adapter_image_proj = ip_adapter_image_proj # facein if isinstance(face_emb_extractor, nn.Module): face_emb_extractor.to(dtype=self.unet.dtype, device=self.unet.device) self.face_emb_extractor = face_emb_extractor if isinstance(facein_image_proj, nn.Module): facein_image_proj.to(dtype=self.unet.dtype, device=self.unet.device) self.facein_image_proj = facein_image_proj # ip_adapter_face if isinstance(ip_adapter_face_emb_extractor, nn.Module): ip_adapter_face_emb_extractor.to( dtype=self.unet.dtype, device=self.unet.device ) self.ip_adapter_face_emb_extractor = ip_adapter_face_emb_extractor if isinstance(ip_adapter_face_image_proj, nn.Module): ip_adapter_face_image_proj.to( dtype=self.unet.dtype, device=self.unet.device ) self.ip_adapter_face_image_proj = ip_adapter_face_image_proj if isinstance(pose_guider, nn.Module): pose_guider.to(dtype=self.unet.dtype, device=self.unet.device) self.pose_guider = pose_guider def decode_latents(self, latents): batch_size = latents.shape[0] latents = rearrange(latents, "b c f h w -> (b f) c h w") video = super().decode_latents(latents=latents) video = rearrange(video, "(b f) h w c -> b c f h w", b=batch_size) return video def prepare_latents( self, batch_size: int, num_channels_latents: int, video_length: int, height: int, width: int, dtype: torch.dtype, device: torch.device, generator: torch.Generator, latents: torch.Tensor = None, w_ind_noise: float = 0.5, image: torch.Tensor = None, timestep: int = None, initial_common_latent: torch.Tensor = None, noise_type: str = "random", add_latents_noise: bool = False, need_img_based_video_noise: bool = False, condition_latents: torch.Tensor = None, img_weight=1e-3, ) -> torch.Tensor: """ 支持多种情况下的latens: img_based_latents: 当Image t=1,latents=None时,使用image赋值到shape,然后加噪;适用于text2video、middle2video。 video_based_latents:image =shape或Latents!=None时,加噪,适用于video2video; noise_latents:当image 和latents都为None时,生成随机噪声,适用于text2video support multi latents condition: img_based_latents: when Image t=1, latents=None, use image to assign to shape, then add noise; suitable for text2video, middle2video. video_based_latents: image =shape or Latents!=None, add noise, suitable for video2video; noise_laten: when image and latents are both None, generate random noise, suitable for text2video Args: batch_size (int): _description_ num_channels_latents (int): _description_ video_length (int): _description_ height (int): _description_ width (int): _description_ dtype (torch.dtype): _description_ device (torch.device): _description_ generator (torch.Generator): _description_ latents (torch.Tensor, optional): _description_. Defaults to None. w_ind_noise (float, optional): _description_. Defaults to 0.5. image (torch.Tensor, optional): _description_. Defaults to None. timestep (int, optional): _description_. Defaults to None. initial_common_latent (torch.Tensor, optional): _description_. Defaults to None. noise_type (str, optional): _description_. Defaults to "random". add_latents_noise (bool, optional): _description_. Defaults to False. need_img_based_video_noise (bool, optional): _description_. Defaults to False. condition_latents (torch.Tensor, optional): _description_. Defaults to None. img_weight (_type_, optional): _description_. Defaults to 1e-3. Raises: ValueError: _description_ ValueError: _description_ ValueError: _description_ Returns: torch.Tensor: latents """ # ref https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py#L691 # ref https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet.py#L659 shape = ( batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor, ) if latents is None or (latents is not None and add_latents_noise): if noise_type == "random": noise = random_noise( shape=shape, dtype=dtype, device=device, generator=generator ) elif noise_type == "video_fusion": noise = video_fusion_noise( shape=shape, dtype=dtype, device=device, generator=generator, w_ind_noise=w_ind_noise, initial_common_noise=initial_common_latent, ) if ( need_img_based_video_noise and condition_latents is not None and image is None and latents is None ): if self.print_idx == 0: logger.debug( ( f"need_img_based_video_noise, condition_latents={condition_latents.shape}," f"batch_size={batch_size}, noise={noise.shape}, video_length={video_length}" ) ) condition_latents = condition_latents.mean(dim=2, keepdim=True) condition_latents = repeat( condition_latents, "b c t h w->b c (t x) h w", x=video_length ) noise = ( img_weight**0.5 * condition_latents + (1 - img_weight) ** 0.5 * noise ) if self.print_idx == 0: logger.debug(f"noise={noise.shape}") if image is not None: if image.ndim == 5: image = rearrange(image, "b c t h w->(b t) c h w") image = image.to(device=device, dtype=dtype) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if isinstance(generator, list): init_latents = [ # self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) self.vae.encode(image[i : i + 1]).latent_dist.mean for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: # init_latents = self.vae.encode(image).latent_dist.sample(generator) init_latents = self.vae.encode(image).latent_dist.mean init_latents = self.vae.config.scaling_factor * init_latents # scale the initial noise by the standard deviation required by the scheduler if ( batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0 ): # expand init_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." ) deprecate( "len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False, ) additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat( [init_latents] * additional_image_per_prompt, dim=0 ) elif ( batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0 ): raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) else: init_latents = torch.cat([init_latents], dim=0) if init_latents.shape[2] != shape[3] and init_latents.shape[3] != shape[4]: init_latents = torch.nn.functional.interpolate( init_latents, size=(shape[3], shape[4]), mode="bilinear", ) init_latents = rearrange( init_latents, "(b t) c h w-> b c t h w", t=video_length ) if self.print_idx == 0: logger.debug(f"init_latensts={init_latents.shape}") if latents is None: if image is None: latents = noise * self.scheduler.init_noise_sigma else: if self.print_idx == 0: logger.debug(f"prepare latents, image is not None") latents = self.scheduler.add_noise(init_latents, noise, timestep) else: if isinstance(latents, np.ndarray): latents = torch.from_numpy(latents) latents = latents.to(device=device, dtype=dtype) if add_latents_noise: latents = self.scheduler.add_noise(latents, noise, timestep) else: latents = latents * self.scheduler.init_noise_sigma if latents.shape != shape: raise ValueError( f"Unexpected latents shape, got {latents.shape}, expected {shape}" ) latents = latents.to(device, dtype=dtype) return latents def prepare_image( self, image, # b c t h w batch_size, num_images_per_prompt, device, dtype, width=None, height=None, ): return prepare_image( image=image, batch_size=batch_size, num_images_per_prompt=num_images_per_prompt, device=device, dtype=dtype, width=width, height=height, image_processor=self.image_processor, ) def prepare_control_image( self, image, # b c t h w width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance=False, guess_mode=False, ): image = prepare_image( image=image, batch_size=batch_size, num_images_per_prompt=num_images_per_prompt, device=device, dtype=dtype, width=width, height=height, image_processor=self.control_image_processor, ) if do_classifier_free_guidance and not guess_mode: image = torch.cat([image] * 2) return image def check_inputs( self, prompt, image, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, controlnet_conditioning_scale=1, control_guidance_start=0, control_guidance_end=1, ): # TODO: to implement if image is not None: return super().check_inputs( prompt, image, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, ) def hist_match_with_vis_cond( self, video: np.ndarray, target: np.ndarray ) -> np.ndarray: """ video: b c t1 h w target: b c t2(=1) h w """ video = hist_match_video_bcthw(video, target, value=255.0) return video def get_facein_image_emb( self, refer_face_image, device, dtype, batch_size, do_classifier_free_guidance ): # refer_face_image and its face_emb if self.print_idx == 0: logger.debug( f"face_emb_extractor={type(self.face_emb_extractor)}, facein_image_proj={type(self.facein_image_proj)}, refer_face_image={type(refer_face_image)}, " ) if ( self.face_emb_extractor is not None and self.facein_image_proj is not None and refer_face_image is not None ): if self.print_idx == 0: logger.debug(f"refer_face_image={refer_face_image.shape}") if isinstance(refer_face_image, np.ndarray): refer_face_image = torch.from_numpy(refer_face_image) refer_face_image_facein = refer_face_image n_refer_face_image = refer_face_image_facein.shape[2] refer_face_image_facein = rearrange( refer_face_image, "b c t h w-> (b t) h w c" ) # refer_face_image_emb: bt d或者 bt h w d ( refer_face_image_emb, refer_align_face_image, ) = self.face_emb_extractor.extract_images( refer_face_image_facein, return_type="torch" ) refer_face_image_emb = refer_face_image_emb.to(device=device, dtype=dtype) if self.print_idx == 0: logger.debug(f"refer_face_image_emb={refer_face_image_emb.shape}") if refer_face_image_emb.shape == 2: refer_face_image_emb = rearrange(refer_face_image_emb, "bt d-> bt 1 d") elif refer_face_image_emb.shape == 4: refer_face_image_emb = rearrange( refer_face_image_emb, "bt h w d-> bt (h w) d" ) refer_face_image_emb_bk = refer_face_image_emb refer_face_image_emb = self.facein_image_proj(refer_face_image_emb) # Todo:当前不支持 IPAdapterPlus的vision_clip的输出 refer_face_image_emb = rearrange( refer_face_image_emb, "(b t) n q-> b (t n) q", t=n_refer_face_image, ) refer_face_image_emb = align_repeat_tensor_single_dim( refer_face_image_emb, target_length=batch_size, dim=0 ) if do_classifier_free_guidance: # TODO:固定特征,有优化空间 # TODO: fix the feature, there is optimization space uncond_refer_face_image_emb = self.facein_image_proj( torch.zeros_like(refer_face_image_emb_bk).to( device=device, dtype=dtype ) ) # Todo:当前可能不支持 IPAdapterPlus的vision_clip的输出 # TODO: do not support IPAdapterPlus's vision_clip's output uncond_refer_face_image_emb = rearrange( uncond_refer_face_image_emb, "(b t) n q-> b (t n) q", t=n_refer_face_image, ) uncond_refer_face_image_emb = align_repeat_tensor_single_dim( uncond_refer_face_image_emb, target_length=batch_size, dim=0 ) if self.print_idx == 0: logger.debug( f"uncond_refer_face_image_emb, {uncond_refer_face_image_emb.shape}" ) logger.debug(f"refer_face_image_emb, {refer_face_image_emb.shape}") refer_face_image_emb = torch.concat( [ uncond_refer_face_image_emb, refer_face_image_emb, ], ) else: refer_face_image_emb = None if self.print_idx == 0: logger.debug(f"refer_face_image_emb={type(refer_face_image_emb)}") return refer_face_image_emb def get_ip_adapter_face_emb( self, refer_face_image, device, dtype, batch_size, do_classifier_free_guidance ): # refer_face_image and its ip_adapter_face_emb if self.print_idx == 0: logger.debug( f"face_emb_extractor={type(self.face_emb_extractor)}, ip_adapter__image_proj={type(self.facein_image_proj)}, refer_face_image={type(refer_face_image)}, " ) if ( self.ip_adapter_face_emb_extractor is not None and self.ip_adapter_face_image_proj is not None and refer_face_image is not None ): if self.print_idx == 0: logger.debug(f"refer_face_image={refer_face_image.shape}") if isinstance(refer_face_image, np.ndarray): refer_face_image = torch.from_numpy(refer_face_image) refer_ip_adapter_face_image = refer_face_image n_refer_face_image = refer_ip_adapter_face_image.shape[2] refer_ip_adapter_face_image = rearrange( refer_ip_adapter_face_image, "b c t h w-> (b t) h w c" ) # refer_face_image_emb: bt d or bt h w d ( refer_face_image_emb, refer_align_face_image, ) = self.ip_adapter_face_emb_extractor.extract_images( refer_ip_adapter_face_image, return_type="torch" ) refer_face_image_emb = refer_face_image_emb.to(device=device, dtype=dtype) if self.print_idx == 0: logger.debug(f"refer_face_image_emb={refer_face_image_emb.shape}") if refer_face_image_emb.shape == 2: refer_face_image_emb = rearrange(refer_face_image_emb, "bt d-> bt 1 d") elif refer_face_image_emb.shape == 4: refer_face_image_emb = rearrange( refer_face_image_emb, "bt h w d-> bt (h w) d" ) refer_face_image_emb_bk = refer_face_image_emb refer_face_image_emb = self.ip_adapter_face_image_proj(refer_face_image_emb) refer_face_image_emb = rearrange( refer_face_image_emb, "(b t) n q-> b (t n) q", t=n_refer_face_image, ) refer_face_image_emb = align_repeat_tensor_single_dim( refer_face_image_emb, target_length=batch_size, dim=0 ) if do_classifier_free_guidance: # TODO:固定特征,有优化空间 # TODO: fix the feature, there is optimization space uncond_refer_face_image_emb = self.ip_adapter_face_image_proj( torch.zeros_like(refer_face_image_emb_bk).to( device=device, dtype=dtype ) ) # TODO: 当前可能不支持 IPAdapterPlus的vision_clip的输出 # TODO: do not support IPAdapterPlus's vision_clip's output uncond_refer_face_image_emb = rearrange( uncond_refer_face_image_emb, "(b t) n q-> b (t n) q", t=n_refer_face_image, ) uncond_refer_face_image_emb = align_repeat_tensor_single_dim( uncond_refer_face_image_emb, target_length=batch_size, dim=0 ) if self.print_idx == 0: logger.debug( f"uncond_refer_face_image_emb, {uncond_refer_face_image_emb.shape}" ) logger.debug(f"refer_face_image_emb, {refer_face_image_emb.shape}") refer_face_image_emb = torch.concat( [ uncond_refer_face_image_emb, refer_face_image_emb, ], ) else: refer_face_image_emb = None if self.print_idx == 0: logger.debug(f"ip_adapter_face_emb={type(refer_face_image_emb)}") return refer_face_image_emb def get_ip_adapter_image_emb( self, ip_adapter_image, device, dtype, batch_size, do_classifier_free_guidance, height, width, ): # refer_image vision_clip and its ipadapter_emb if self.print_idx == 0: logger.debug( f"vision_clip_extractor={type(self.vision_clip_extractor)}," f"ip_adapter_image_proj={type(self.ip_adapter_image_proj)}," f"ip_adapter_image={type(ip_adapter_image)}," ) if self.vision_clip_extractor is not None and ip_adapter_image is not None: if self.print_idx == 0: logger.debug(f"ip_adapter_image={ip_adapter_image.shape}") if isinstance(ip_adapter_image, np.ndarray): ip_adapter_image = torch.from_numpy(ip_adapter_image) # ip_adapter_image = ip_adapter_image.to(device=device, dtype=dtype) n_ip_adapter_image = ip_adapter_image.shape[2] ip_adapter_image = rearrange(ip_adapter_image, "b c t h w-> (b t) h w c") ip_adapter_image_emb = self.vision_clip_extractor.extract_images( ip_adapter_image, target_height=height, target_width=width, return_type="torch", ) if ip_adapter_image_emb.ndim == 2: ip_adapter_image_emb = rearrange(ip_adapter_image_emb, "b q-> b 1 q") ip_adapter_image_emb_bk = ip_adapter_image_emb # 存在只需要image_prompt、但不需要 proj的场景,如使用image_prompt替代text_prompt # There are scenarios where only image_prompt is needed, but proj is not needed, such as using image_prompt instead of text_prompt if self.ip_adapter_image_proj is not None: logger.debug(f"ip_adapter_image_proj is None, ") ip_adapter_image_emb = self.ip_adapter_image_proj(ip_adapter_image_emb) # TODO: 当前不支持 IPAdapterPlus的vision_clip的输出 # TODO: do not support IPAdapterPlus's vision_clip's output ip_adapter_image_emb = rearrange( ip_adapter_image_emb, "(b t) n q-> b (t n) q", t=n_ip_adapter_image, ) ip_adapter_image_emb = align_repeat_tensor_single_dim( ip_adapter_image_emb, target_length=batch_size, dim=0 ) if do_classifier_free_guidance: # TODO:固定特征,有优化空间 # TODO: fix the feature, there is optimization space if self.ip_adapter_image_proj is not None: uncond_ip_adapter_image_emb = self.ip_adapter_image_proj( torch.zeros_like(ip_adapter_image_emb_bk).to( device=device, dtype=dtype ) ) if self.print_idx == 0: logger.debug( f"uncond_ip_adapter_image_emb use ip_adapter_image_proj(zero_like)" ) else: uncond_ip_adapter_image_emb = torch.zeros_like(ip_adapter_image_emb) if self.print_idx == 0: logger.debug(f"uncond_ip_adapter_image_emb use zero_like") # TODO:当前可能不支持 IPAdapterPlus的vision_clip的输出 # TODO: do not support IPAdapterPlus's vision_clip's output uncond_ip_adapter_image_emb = rearrange( uncond_ip_adapter_image_emb, "(b t) n q-> b (t n) q", t=n_ip_adapter_image, ) uncond_ip_adapter_image_emb = align_repeat_tensor_single_dim( uncond_ip_adapter_image_emb, target_length=batch_size, dim=0 ) if self.print_idx == 0: logger.debug( f"uncond_ip_adapter_image_emb, {uncond_ip_adapter_image_emb.shape}" ) logger.debug(f"ip_adapter_image_emb, {ip_adapter_image_emb.shape}") # uncond_ip_adapter_image_emb = torch.zeros_like(ip_adapter_image_emb) ip_adapter_image_emb = torch.concat( [ uncond_ip_adapter_image_emb, ip_adapter_image_emb, ], ) else: ip_adapter_image_emb = None if self.print_idx == 0: logger.debug(f"ip_adapter_image_emb={type(ip_adapter_image_emb)}") return ip_adapter_image_emb def get_referencenet_image_vae_emb( self, refer_image, batch_size, num_videos_per_prompt, device, dtype, do_classifier_free_guidance, width: int = None, height: int = None, ): # prepare_referencenet_emb if self.print_idx == 0: logger.debug( f"referencenet={type(self.referencenet)}, refer_image={type(refer_image)}" ) if self.referencenet is not None and refer_image is not None: n_refer_image = refer_image.shape[2] refer_image_vae = self.prepare_image( refer_image, batch_size=batch_size * num_videos_per_prompt, num_images_per_prompt=num_videos_per_prompt, device=device, dtype=dtype, width=width, height=height, ) # ref_hidden_states = self.vae.encode(refer_image_vae).latent_dist.sample() refer_image_vae_emb = self.vae.encode(refer_image_vae).latent_dist.mean refer_image_vae_emb = self.vae.config.scaling_factor * refer_image_vae_emb logger.debug(f"refer_image_vae_emb={refer_image_vae_emb.shape}") if do_classifier_free_guidance: # 1. zeros_like image # uncond_refer_image_vae_emb = self.vae.encode( # torch.zeros_like(refer_image_vae) # ).latent_dist.mean # uncond_refer_image_vae_emb = ( # self.vae.config.scaling_factor * uncond_refer_image_vae_emb # ) # 2. zeros_like image vae emb # uncond_refer_image_vae_emb = torch.zeros_like(refer_image_vae_emb) # uncond_refer_image_vae_emb = rearrange( # uncond_refer_image_vae_emb, # "(b t) c h w-> b c t h w", # t=n_refer_image, # ) # refer_image_vae_emb = rearrange( # refer_image_vae_emb, "(b t) c h w-> b c t h w", t=n_refer_image # ) # refer_image_vae_emb = torch.concat( # [uncond_refer_image_vae_emb, refer_image_vae_emb], dim=0 # ) # refer_image_vae_emb = rearrange( # refer_image_vae_emb, "b c t h w-> (b t) c h w" # ) # logger.debug(f"refer_image_vae_emb={refer_image_vae_emb.shape}") # 3. uncond_refer_image_vae_emb = refer_image_vae_emb uncond_refer_image_vae_emb = refer_image_vae_emb uncond_refer_image_vae_emb = rearrange( uncond_refer_image_vae_emb, "(b t) c h w-> b c t h w", t=n_refer_image, ) refer_image_vae_emb = rearrange( refer_image_vae_emb, "(b t) c h w-> b c t h w", t=n_refer_image ) refer_image_vae_emb = torch.concat( [uncond_refer_image_vae_emb, refer_image_vae_emb], dim=0 ) refer_image_vae_emb = rearrange( refer_image_vae_emb, "b c t h w-> (b t) c h w" ) logger.debug(f"refer_image_vae_emb={refer_image_vae_emb.shape}") else: refer_image_vae_emb = None return refer_image_vae_emb def get_referencenet_emb( self, refer_image_vae_emb, refer_image, batch_size, num_videos_per_prompt, device, dtype, ip_adapter_image_emb, do_classifier_free_guidance, prompt_embeds, ref_timestep_int: int = 0, ): # prepare_referencenet_emb if self.print_idx == 0: logger.debug( f"referencenet={type(self.referencenet)}, refer_image={type(refer_image)}" ) if ( self.referencenet is not None and refer_image_vae_emb is not None and refer_image is not None ): n_refer_image = refer_image.shape[2] # ref_timestep = ( # torch.ones((refer_image_vae_emb.shape[0],), device=device) # * ref_timestep_int # ) ref_timestep = torch.zeros_like(ref_timestep_int) # referencenet 优先使用 ip_adapter 中图像提取到的 clip_vision_emb if ip_adapter_image_emb is not None: refer_prompt_embeds = ip_adapter_image_emb else: refer_prompt_embeds = prompt_embeds if self.print_idx == 0: logger.debug( f"use referencenet: n_refer_image={n_refer_image}, refer_image_vae_emb={refer_image_vae_emb.shape}, ref_timestep={ref_timestep.shape}" ) if prompt_embeds is not None: logger.debug(f"prompt_embeds={prompt_embeds.shape},") # refer_image_vae_emb = self.scheduler.scale_model_input( # refer_image_vae_emb, ref_timestep # ) # self.scheduler._step_index = None # self.scheduler.is_scale_input_called = False referencenet_params = { "sample": refer_image_vae_emb, "encoder_hidden_states": refer_prompt_embeds, "timestep": ref_timestep, "num_frames": n_refer_image, "return_ndim": 5, } ( down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb, ) = self.referencenet(**referencenet_params) # many ways to prepare negative referencenet emb # mode 1 # zero shape like ref_image # if do_classifier_free_guidance: # # mode 2: # # if down_block_refer_embs is not None: # # down_block_refer_embs = [ # # torch.cat([x] * 2) for x in down_block_refer_embs # # ] # # if mid_block_refer_emb is not None: # # mid_block_refer_emb = torch.cat([mid_block_refer_emb] * 2) # # if refer_self_attn_emb is not None: # # refer_self_attn_emb = [ # # torch.cat([x] * 2) for x in refer_self_attn_emb # # ] # # mode 3 # if down_block_refer_embs is not None: # down_block_refer_embs = [ # torch.cat([torch.zeros_like(x), x]) # for x in down_block_refer_embs # ] # if mid_block_refer_emb is not None: # mid_block_refer_emb = torch.cat( # [torch.zeros_like(mid_block_refer_emb), mid_block_refer_emb] * 2 # ) # if refer_self_attn_emb is not None: # refer_self_attn_emb = [ # torch.cat([torch.zeros_like(x), x]) for x in refer_self_attn_emb # ] else: down_block_refer_embs = None mid_block_refer_emb = None refer_self_attn_emb = None if self.print_idx == 0: logger.debug(f"down_block_refer_embs={type(down_block_refer_embs)}") logger.debug(f"mid_block_refer_emb={type(mid_block_refer_emb)}") logger.debug(f"refer_self_attn_emb={type(refer_self_attn_emb)}") return down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb def prepare_condition_latents_and_index( self, condition_images, condition_latents, video_length, batch_size, dtype, device, latent_index, vision_condition_latent_index, ): # prepare condition_latents if condition_images is not None and condition_latents is None: # condition_latents = self.vae.encode(condition_images).latent_dist.sample() condition_latents = self.vae.encode(condition_images).latent_dist.mean condition_latents = self.vae.config.scaling_factor * condition_latents condition_latents = rearrange( condition_latents, "(b t) c h w-> b c t h w", b=batch_size ) if self.print_idx == 0: logger.debug( f"condition_latents from condition_images, shape is condition_latents={condition_latents.shape}", ) if condition_latents is not None: total_frames = condition_latents.shape[2] + video_length if isinstance(condition_latents, np.ndarray): condition_latents = torch.from_numpy(condition_latents) condition_latents = condition_latents.to(dtype=dtype, device=device) # if condition is None, mean condition_latents head, generated video is tail if vision_condition_latent_index is not None: # vision_condition_latent_index should be list, whose length is condition_latents.shape[2] # -1 -> will be converted to condition_latents.shape[2]+video_length vision_condition_latent_index_lst = [ i_v if i_v != -1 else total_frames - 1 for i_v in vision_condition_latent_index ] vision_condition_latent_index = torch.LongTensor( vision_condition_latent_index_lst, ).to(device=device) if self.print_idx == 0: logger.debug( f"vision_condition_latent_index {type(vision_condition_latent_index)}, {vision_condition_latent_index}" ) else: # [0, condition_latents.shape[2]] vision_condition_latent_index = torch.arange( condition_latents.shape[2], dtype=torch.long, device=device ) vision_condition_latent_index_lst = ( vision_condition_latent_index.tolist() ) if latent_index is None: # [condition_latents.shape[2], condition_latents.shape[2]+video_length] latent_index_lst = sorted( list( set(range(total_frames)) - set(vision_condition_latent_index_lst) ) ) latent_index = torch.LongTensor( latent_index_lst, ).to(device=device) if vision_condition_latent_index is not None: vision_condition_latent_index = vision_condition_latent_index.to( device=device ) if self.print_idx == 0: logger.debug( f"pipeline vision_condition_latent_index ={vision_condition_latent_index.shape}, {vision_condition_latent_index}" ) if latent_index is not None: latent_index = latent_index.to(device=device) if self.print_idx == 0: logger.debug( f"pipeline latent_index ={latent_index.shape}, {latent_index}" ) logger.debug(f"condition_latents={type(condition_latents)}") logger.debug(f"latent_index={type(latent_index)}") logger.debug( f"vision_condition_latent_index={type(vision_condition_latent_index)}" ) return condition_latents, latent_index, vision_condition_latent_index def prepare_controlnet_and_guidance_parameter( self, control_guidance_start, control_guidance_end ): controlnet = ( self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet ) # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance( control_guidance_end, list ): control_guidance_start = len(control_guidance_end) * [ control_guidance_start ] elif not isinstance(control_guidance_end, list) and isinstance( control_guidance_start, list ): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance( control_guidance_end, list ): mult = ( len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 ) control_guidance_start, control_guidance_end = mult * [ control_guidance_start ], mult * [control_guidance_end] return controlnet, control_guidance_start, control_guidance_end def prepare_controlnet_guess_mode(self, controlnet, guess_mode): global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) guess_mode = guess_mode or global_pool_conditions return guess_mode def prepare_controlnet_image_and_latents( self, controlnet, width, height, batch_size, num_videos_per_prompt, device, dtype, controlnet_latents=None, controlnet_condition_latents=None, control_image=None, controlnet_condition_images=None, guess_mode=False, do_classifier_free_guidance=False, ): if isinstance(controlnet, ControlNetModel): if controlnet_latents is not None: if isinstance(controlnet_latents, np.ndarray): controlnet_latents = torch.from_numpy(controlnet_latents) if controlnet_condition_latents is not None: if isinstance(controlnet_condition_latents, np.ndarray): controlnet_condition_latents = torch.from_numpy( controlnet_condition_latents ) # TODO:使用index进行concat controlnet_latents = torch.concat( [controlnet_condition_latents, controlnet_latents], dim=2 ) if not guess_mode and do_classifier_free_guidance: controlnet_latents = torch.concat([controlnet_latents] * 2, dim=0) controlnet_latents = rearrange( controlnet_latents, "b c t h w->(b t) c h w" ) controlnet_latents = controlnet_latents.to(device=device, dtype=dtype) if self.print_idx == 0: logger.debug( f"call, controlnet_latents.shape, f{controlnet_latents.shape}" ) else: # TODO: concat with index if isinstance(control_image, np.ndarray): control_image = torch.from_numpy(control_image) if controlnet_condition_images is not None: if isinstance(controlnet_condition_images, np.ndarray): controlnet_condition_images = torch.from_numpy( controlnet_condition_images ) control_image = torch.concatenate( [controlnet_condition_images, control_image], dim=2 ) control_image = self.prepare_control_image( image=control_image, width=width, height=height, batch_size=batch_size * num_videos_per_prompt, num_images_per_prompt=num_videos_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) height, width = control_image.shape[-2:] if self.print_idx == 0: logger.debug(f"call, control_image.shape , {control_image.shape}") elif isinstance(controlnet, MultiControlNetModel): control_images = [] # TODO: directly support contronet_latent instead of frames if ( controlnet_latents is not None and controlnet_condition_latents is not None ): raise NotImplementedError for i, control_image_ in enumerate(control_image): if controlnet_condition_images is not None and isinstance( controlnet_condition_images, list ): if isinstance(controlnet_condition_images[i], np.ndarray): control_image_ = np.concatenate( [controlnet_condition_images[i], control_image_], axis=2 ) control_image_ = self.prepare_control_image( image=control_image_, width=width, height=height, batch_size=batch_size * num_videos_per_prompt, num_images_per_prompt=num_videos_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) control_images.append(control_image_) control_image = control_images height, width = control_image[0].shape[-2:] else: assert False if control_image is not None: if not isinstance(control_image, list): if self.print_idx == 0: logger.debug(f"control_image shape is {control_image.shape}") else: if self.print_idx == 0: logger.debug(f"control_image shape is {control_image[0].shape}") return control_image, controlnet_latents def get_controlnet_emb( self, run_controlnet, guess_mode, do_classifier_free_guidance, latents, prompt_embeds, latent_model_input, controlnet_keep, controlnet_conditioning_scale, control_image, controlnet_latents, i, t, ): if run_controlnet and self.pose_guider is None: # controlnet(s) inference if guess_mode and do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. control_model_input = latents control_model_input = self.scheduler.scale_model_input( control_model_input, t ) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds if isinstance(controlnet_keep[i], list): cond_scale = [ c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i]) ] else: cond_scale = controlnet_conditioning_scale * controlnet_keep[i] control_model_input_reshape = rearrange( control_model_input, "b c t h w -> (b t) c h w" ) logger.debug( f"control_model_input_reshape={control_model_input_reshape.shape}, controlnet_prompt_embeds={controlnet_prompt_embeds.shape}" ) encoder_hidden_states_repeat = align_repeat_tensor_single_dim( controlnet_prompt_embeds, target_length=control_model_input_reshape.shape[0], dim=0, ) if self.print_idx == 0: logger.debug( f"control_model_input_reshape={control_model_input_reshape.shape}, " f"encoder_hidden_states_repeat={encoder_hidden_states_repeat.shape}, " ) down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input_reshape, t, encoder_hidden_states_repeat, controlnet_cond=control_image, controlnet_cond_latents=controlnet_latents, conditioning_scale=cond_scale, guess_mode=guess_mode, return_dict=False, ) if self.print_idx == 0: logger.debug( f"controlnet, len(down_block_res_samples, {len(down_block_res_samples)}", ) for i_tmp, tmp in enumerate(down_block_res_samples): logger.debug( f"controlnet down_block_res_samples i={i_tmp}, down_block_res_sample={tmp.shape}" ) logger.debug( f"controlnet mid_block_res_sample, {mid_block_res_sample.shape}" ) if guess_mode and do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [ torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples ] mid_block_res_sample = torch.cat( [ torch.zeros_like(mid_block_res_sample), mid_block_res_sample, ] ) else: down_block_res_samples = None mid_block_res_sample = None return down_block_res_samples, mid_block_res_sample @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, video_length: Optional[int], prompt: Union[str, List[str]] = None, # b c t h w image: Union[ torch.FloatTensor, PIL.Image.Image, np.ndarray, List[torch.FloatTensor], List[PIL.Image.Image], List[np.ndarray], ] = None, control_image: Union[ torch.FloatTensor, PIL.Image.Image, np.ndarray, List[torch.FloatTensor], List[PIL.Image.Image], List[np.ndarray], ] = None, # b c t(1) ho wo condition_images: Optional[torch.FloatTensor] = None, condition_latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None, add_latents_noise: bool = False, height: Optional[int] = None, width: Optional[int] = None, strength: float = 0.8, num_inference_steps: int = 50, guidance_scale: float = 7.5, guidance_scale_end: float = None, guidance_scale_method: str = "linear", negative_prompt: Optional[Union[str, List[str]]] = None, num_videos_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, # b c t(1) hi wi controlnet_condition_images: Optional[torch.FloatTensor] = None, # b c t(1) ho wo controlnet_condition_latents: Optional[torch.FloatTensor] = None, controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "tensor", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, need_middle_latents: bool = False, w_ind_noise: float = 0.5, initial_common_latent: Optional[torch.FloatTensor] = None, latent_index: torch.LongTensor = None, vision_condition_latent_index: torch.LongTensor = None, # noise parameters noise_type: str = "random", need_img_based_video_noise: bool = False, skip_temporal_layer: bool = False, img_weight: float = 1e-3, need_hist_match: bool = False, motion_speed: float = 8.0, refer_image: Optional[Tuple[torch.Tensor, np.array]] = None, ip_adapter_image: Optional[Tuple[torch.Tensor, np.array]] = None, refer_face_image: Optional[Tuple[torch.Tensor, np.array]] = None, ip_adapter_scale: float = 1.0, facein_scale: float = 1.0, ip_adapter_face_scale: float = 1.0, ip_adapter_face_image: Optional[Tuple[torch.Tensor, np.array]] = None, prompt_only_use_image_prompt: bool = False, # serial_denoise parameter start record_mid_video_noises: bool = False, last_mid_video_noises: List[torch.Tensor] = None, record_mid_video_latents: bool = False, last_mid_video_latents: List[torch.TensorType] = None, video_overlap: int = 1, # serial_denoise parameter end # parallel_denoise parameter start # refer to https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/src/pipelines/pipeline_pose2vid_long.py#L354 context_schedule="uniform", context_frames=12, context_stride=1, context_overlap=4, context_batch_size=1, interpolation_factor=1, # parallel_denoise parameter end ): r""" 旨在兼容text2video、text2image、img2img、video2video、是否有controlnet等的通用pipeline。目前仅不支持img2img、video2video。 支持多片段同时denoise,交叉部分加权平均 当 skip_temporal_layer 为 False 时, unet 起 video 生成作用;skip_temporal_layer为True时,unet起原image作用。 当controlnet的所有入参为None,等价于走的是text2video pipeline; 当 condition_latents、controlnet_condition_images、controlnet_condition_latents为None时,表示不走首帧条件生成的时序condition pipeline 现在没有考虑对 `num_videos_per_prompt` 的兼容性,不是1可能报错; if skip_temporal_layer is False, unet motion layer works, else unet only run text2image layers. if parameters about controlnet are None, means text2video pipeline; if ondition_latents、controlnet_condition_images、controlnet_condition_latents are None, means only run text2video without vision condition images. By now, code works well with `num_videos_per_prpmpt=1`, !=1 may be wrong. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in init, images must be passed as a list such that each element of the list can be correctly batched for input to a single controlnet. condition_latents: 与latents相对应,是Latents的时序condition,一般为首帧,b c t(1) ho wo be corresponding to latents, vision condtion latents, usually first frame, should be b c t(1) ho wo. controlnet_latents: 与image二选一,image会被转化成controlnet_latents Choose either image or controlnet_latents. If image is chosen, it will be converted to controlnet_latents. controlnet_condition_images: Optional[torch.FloatTensor]# b c t(1) ho wo,与image相对应,会和image在t通道concat一起,然后转化成 controlnet_latents b c t(1) ho wo, corresponding to image, will be concatenated along the t channel with image and then converted to controlnet_latents. controlnet_condition_latents: Optional[torch.FloatTensor]:# b c t(1) ho wo,会和 controlnet_latents 在t 通道concat一起,转化成 controlnet_latents b c t(1) ho wo will be concatenated along the t channel with controlnet_latents and converted to controlnet_latents. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). strength (`float`, *optional*, defaults to 0.8): Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a starting point and more noise is added the higher the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 essentially ignores `image`. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original unet. If multiple ControlNets are specified in init, you can set the corresponding scale as a list. guess_mode (`bool`, *optional*, defaults to `False`): In this mode, the ControlNet encoder will try best to recognize the content of the input image even if you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the controlnet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the controlnet stops applying. skip_temporal_layer (`bool`: default to False) 为False时,unet起video生成作用,会运行时序生成的block;skip_temporal_layer为True时,unet起原image作用,跳过时序生成的block。 need_img_based_video_noise: bool = False, 当只有首帧latents时,是否需要扩展为video noise; num_videos_per_prompt: now only support 1. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ run_controlnet = control_image is not None or controlnet_latents is not None if run_controlnet: ( controlnet, control_guidance_start, control_guidance_end, ) = self.prepare_controlnet_and_guidance_parameter( control_guidance_start=control_guidance_start, control_guidance_end=control_guidance_end, ) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, control_image, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device dtype = self.unet.dtype # print("pipeline unet dtype", dtype) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 if run_controlnet: if isinstance(controlnet, MultiControlNetModel) and isinstance( controlnet_conditioning_scale, float ): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len( controlnet.nets ) guess_mode = self.prepare_controlnet_guess_mode( controlnet=controlnet, guess_mode=guess_mode, ) # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) if self.text_encoder is not None: prompt_embeds = encode_weighted_prompt( self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, # lora_scale=text_encoder_lora_scale, ) logger.debug(f"use text_encoder prepare prompt_emb={prompt_embeds.shape}") else: prompt_embeds = None if image is not None: image = self.prepare_image( image, width=width, height=height, batch_size=batch_size * num_videos_per_prompt, num_images_per_prompt=num_videos_per_prompt, device=device, dtype=dtype, ) if self.print_idx == 0: logger.debug(f"image={image.shape}") if condition_images is not None: condition_images = self.prepare_image( condition_images, width=width, height=height, batch_size=batch_size * num_videos_per_prompt, num_images_per_prompt=num_videos_per_prompt, device=device, dtype=dtype, ) if self.print_idx == 0: logger.debug(f"condition_images={condition_images.shape}") # 4. Prepare image if run_controlnet: ( control_image, controlnet_latents, ) = self.prepare_controlnet_image_and_latents( controlnet=controlnet, width=width, height=height, batch_size=batch_size, num_videos_per_prompt=num_videos_per_prompt, device=device, dtype=dtype, controlnet_condition_latents=controlnet_condition_latents, control_image=control_image, controlnet_condition_images=controlnet_condition_images, guess_mode=guess_mode, do_classifier_free_guidance=do_classifier_free_guidance, controlnet_latents=controlnet_latents, ) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) if strength and (image is not None and latents is not None): if self.print_idx == 0: logger.debug( f"prepare timesteps, with get_timesteps strength={strength}, num_inference_steps={num_inference_steps}" ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, device ) else: if self.print_idx == 0: logger.debug(f"prepare timesteps, without get_timesteps") timesteps = self.scheduler.timesteps latent_timestep = timesteps[:1].repeat( batch_size * num_videos_per_prompt ) # 6. Prepare latent variables ( condition_latents, latent_index, vision_condition_latent_index, ) = self.prepare_condition_latents_and_index( condition_images=condition_images, condition_latents=condition_latents, video_length=video_length, batch_size=batch_size, dtype=dtype, device=device, latent_index=latent_index, vision_condition_latent_index=vision_condition_latent_index, ) if vision_condition_latent_index is None: n_vision_cond = 0 else: n_vision_cond = vision_condition_latent_index.shape[0] num_channels_latents = self.unet.config.in_channels if self.print_idx == 0: logger.debug(f"pipeline controlnet, start prepare latents") latents = self.prepare_latents( batch_size=batch_size * num_videos_per_prompt, num_channels_latents=num_channels_latents, video_length=video_length, height=height, width=width, dtype=dtype, device=device, generator=generator, latents=latents, image=image, timestep=latent_timestep, w_ind_noise=w_ind_noise, initial_common_latent=initial_common_latent, noise_type=noise_type, add_latents_noise=add_latents_noise, need_img_based_video_noise=need_img_based_video_noise, condition_latents=condition_latents, img_weight=img_weight, ) if self.print_idx == 0: logger.debug(f"pipeline controlnet, finish prepare latents={latents.shape}") # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) if noise_type == "video_fusion" and "noise_type" in set( inspect.signature(self.scheduler.step).parameters.keys() ): extra_step_kwargs["w_ind_noise"] = w_ind_noise extra_step_kwargs["noise_type"] = noise_type # extra_step_kwargs["noise_offset"] = noise_offset # 7.1 Create tensor stating which controlnets to keep if run_controlnet: controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append( keeps[0] if isinstance(controlnet, ControlNetModel) else keeps ) else: controlnet_keep = None # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order if skip_temporal_layer: self.unet.set_skip_temporal_layers(True) n_timesteps = len(timesteps) guidance_scale_lst = generate_parameters_with_timesteps( start=guidance_scale, stop=guidance_scale_end, num=n_timesteps, method=guidance_scale_method, ) if self.print_idx == 0: logger.debug( f"guidance_scale_lst, {guidance_scale_method}, {guidance_scale}, {guidance_scale_end}, {guidance_scale_lst}" ) ip_adapter_image_emb = self.get_ip_adapter_image_emb( ip_adapter_image=ip_adapter_image, batch_size=batch_size, device=device, dtype=dtype, do_classifier_free_guidance=do_classifier_free_guidance, height=height, width=width, ) # 当前仅当没有ip_adapter时,按照参数 prompt_only_use_image_prompt 要求是否完全替换 image_prompt_emb # only if ip_adapter is None and prompt_only_use_image_prompt is True, use image_prompt_emb replace text_prompt if ( ip_adapter_image_emb is not None and prompt_only_use_image_prompt and not self.unet.ip_adapter_cross_attn ): prompt_embeds = ip_adapter_image_emb logger.debug(f"use ip_adapter_image_emb replace prompt_embeds") refer_face_image_emb = self.get_facein_image_emb( refer_face_image=refer_face_image, batch_size=batch_size, device=device, dtype=dtype, do_classifier_free_guidance=do_classifier_free_guidance, ) ip_adapter_face_emb = self.get_ip_adapter_face_emb( refer_face_image=ip_adapter_face_image, batch_size=batch_size, device=device, dtype=dtype, do_classifier_free_guidance=do_classifier_free_guidance, ) refer_image_vae_emb = self.get_referencenet_image_vae_emb( refer_image=refer_image, device=device, dtype=dtype, do_classifier_free_guidance=do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, batch_size=batch_size, width=width, height=height, ) if self.pose_guider is not None and control_image is not None: if self.print_idx == 0: logger.debug(f"pose_guider, controlnet_image={control_image.shape}") control_image = rearrange( control_image, " (b t) c h w->b c t h w", t=video_length ) pose_guider_emb = self.pose_guider(control_image) pose_guider_emb = rearrange(pose_guider_emb, "b c t h w-> (b t) c h w") else: pose_guider_emb = None logger.debug(f"prompt_embeds={prompt_embeds.shape}") if control_image is not None: if isinstance(control_image, list): logger.debug(f"control_imageis list, num={len(control_image)}") control_image = [ rearrange( control_image_tmp, " (b t) c h w->b c t h w", b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, ) for control_image_tmp in control_image ] else: logger.debug(f"control_image={control_image.shape}, before") control_image = rearrange( control_image, " (b t) c h w->b c t h w", b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, ) logger.debug(f"control_image={control_image.shape}, after") if controlnet_latents is not None: if isinstance(controlnet_latents, list): logger.debug( f"controlnet_latents is list, num={len(controlnet_latents)}" ) controlnet_latents = [ rearrange( controlnet_latents_tmp, " (b t) c h w->b c t h w", b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, ) for controlnet_latents_tmp in controlnet_latents ] else: logger.debug(f"controlnet_latents={controlnet_latents.shape}, before") controlnet_latents = rearrange( controlnet_latents, " (b t) c h w->b c t h w", b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, ) logger.debug(f"controlnet_latents={controlnet_latents.shape}, after") videos_mid = [] mid_video_noises = [] if record_mid_video_noises else None mid_video_latents = [] if record_mid_video_latents else None global_context = prepare_global_context( context_schedule=context_schedule, num_inference_steps=num_inference_steps, time_size=latents.shape[2], context_frames=context_frames, context_stride=context_stride, context_overlap=context_overlap, context_batch_size=context_batch_size, ) logger.debug( f"context_schedule={context_schedule}, time_size={latents.shape[2]}, context_frames={context_frames}, context_stride={context_stride}, context_overlap={context_overlap}, context_batch_size={context_batch_size}" ) logger.debug(f"global_context={global_context}") # iterative denoise with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # 使用 last_mid_video_latents 来影响初始化latent,该部分效果较差,暂留代码 # use last_mide_video_latents to affect initial latent. works bad, Temporarily reserved if i == 0: if record_mid_video_latents: mid_video_latents.append(latents[:, :, -video_overlap:]) if record_mid_video_noises: mid_video_noises.append(None) if ( last_mid_video_latents is not None and len(last_mid_video_latents) > 0 ): if self.print_idx == 1: logger.debug( f"{i}, last_mid_video_latents={last_mid_video_latents[i].shape}" ) latents = fuse_part_tensor( last_mid_video_latents[0], latents, video_overlap, weight=0.1, skip_step=0, ) noise_pred = torch.zeros( ( latents.shape[0] * (2 if do_classifier_free_guidance else 1), *latents.shape[1:], ), device=latents.device, dtype=latents.dtype, ) counter = torch.zeros( (1, 1, latents.shape[2], 1, 1), device=latents.device, dtype=latents.dtype, ) if i == 0: ( down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb, ) = self.get_referencenet_emb( refer_image_vae_emb=refer_image_vae_emb, refer_image=refer_image, device=device, dtype=dtype, do_classifier_free_guidance=do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=prompt_embeds, ip_adapter_image_emb=ip_adapter_image_emb, batch_size=batch_size, ref_timestep_int=t, ) for context in global_context: # expand the latents if we are doing classifier free guidance latents_c = torch.cat([latents[:, :, c] for c in context]) latent_index_c = ( torch.cat([latent_index[c] for c in context]) if latent_index is not None else None ) latent_model_input = latents_c.to(device).repeat( 2 if do_classifier_free_guidance else 1, 1, 1, 1, 1 ) latent_model_input = self.scheduler.scale_model_input( latent_model_input, t ) sub_latent_index_c = ( torch.LongTensor( torch.arange(latent_index_c.shape[-1]) + n_vision_cond ).to(device=latents_c.device) if latent_index is not None else None ) if condition_latents is not None: latent_model_condition = ( torch.cat([condition_latents] * 2) if do_classifier_free_guidance else latents ) if self.print_idx == 0: logger.debug( f"vision_condition_latent_index, {vision_condition_latent_index.shape}, vision_condition_latent_index" ) logger.debug( f"latent_model_condition, {latent_model_condition.shape}" ) logger.debug(f"latent_index, {latent_index_c.shape}") logger.debug( f"latent_model_input, {latent_model_input.shape}" ) logger.debug(f"sub_latent_index_c, {sub_latent_index_c}") latent_model_input = batch_concat_two_tensor_with_index( data1=latent_model_condition, data1_index=vision_condition_latent_index, data2=latent_model_input, data2_index=sub_latent_index_c, dim=2, ) if control_image is not None: if vision_condition_latent_index is not None: # 获取 vision_condition 对应的 control_imgae/control_latent 部分 # generate control_image/control_latent corresponding to vision_condition controlnet_condtion_latent_index = ( vision_condition_latent_index.clone().cpu().tolist() ) if self.print_idx == 0: logger.debug( f"context={context}, controlnet_condtion_latent_index={controlnet_condtion_latent_index}" ) controlnet_context = [ controlnet_condtion_latent_index + [c_i + n_vision_cond for c_i in c] for c in context ] else: controlnet_context = context if self.print_idx == 0: logger.debug( f"controlnet_context={controlnet_context}, latent_model_input={latent_model_input.shape}" ) if isinstance(control_image, list): control_image_c = [ torch.cat( [ control_image_tmp[:, :, c] for c in controlnet_context ] ) for control_image_tmp in control_image ] control_image_c = [ rearrange(control_image_tmp, " b c t h w-> (b t) c h w") for control_image_tmp in control_image_c ] else: control_image_c = torch.cat( [control_image[:, :, c] for c in controlnet_context] ) control_image_c = rearrange( control_image_c, " b c t h w-> (b t) c h w" ) else: control_image_c = None if controlnet_latents is not None: if vision_condition_latent_index is not None: # 获取 vision_condition 对应的 control_imgae/control_latent 部分 # generate control_image/control_latent corresponding to vision_condition controlnet_condtion_latent_index = ( vision_condition_latent_index.clone().cpu().tolist() ) if self.print_idx == 0: logger.debug( f"context={context}, controlnet_condtion_latent_index={controlnet_condtion_latent_index}" ) controlnet_context = [ controlnet_condtion_latent_index + [c_i + n_vision_cond for c_i in c] for c in context ] else: controlnet_context = context if self.print_idx == 0: logger.debug( f"controlnet_context={controlnet_context}, controlnet_latents={controlnet_latents.shape}, latent_model_input={latent_model_input.shape}," ) controlnet_latents_c = torch.cat( [controlnet_latents[:, :, c] for c in controlnet_context] ) controlnet_latents_c = rearrange( controlnet_latents_c, " b c t h w-> (b t) c h w" ) else: controlnet_latents_c = None ( down_block_res_samples, mid_block_res_sample, ) = self.get_controlnet_emb( run_controlnet=run_controlnet, guess_mode=guess_mode, do_classifier_free_guidance=do_classifier_free_guidance, latents=latents_c, prompt_embeds=prompt_embeds, latent_model_input=latent_model_input, control_image=control_image_c, controlnet_latents=controlnet_latents_c, controlnet_keep=controlnet_keep, t=t, i=i, controlnet_conditioning_scale=controlnet_conditioning_scale, ) if self.print_idx == 0: logger.debug( f"{i}, latent_model_input={latent_model_input.shape}, sub_latent_index_c={sub_latent_index_c}" f"{vision_condition_latent_index}" ) # time.sleep(10) noise_pred_c = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, return_dict=False, sample_index=sub_latent_index_c, vision_conditon_frames_sample_index=vision_condition_latent_index, sample_frame_rate=motion_speed, down_block_refer_embs=down_block_refer_embs, mid_block_refer_emb=mid_block_refer_emb, refer_self_attn_emb=refer_self_attn_emb, vision_clip_emb=ip_adapter_image_emb, face_emb=refer_face_image_emb, ip_adapter_scale=ip_adapter_scale, facein_scale=facein_scale, ip_adapter_face_emb=ip_adapter_face_emb, ip_adapter_face_scale=ip_adapter_face_scale, do_classifier_free_guidance=do_classifier_free_guidance, pose_guider_emb=pose_guider_emb, )[0] if condition_latents is not None: noise_pred_c = batch_index_select( noise_pred_c, dim=2, index=sub_latent_index_c ).contiguous() if self.print_idx == 0: logger.debug( f"{i}, latent_model_input={latent_model_input.shape}, noise_pred_c={noise_pred_c.shape}, {len(context)}, {len(context[0])}" ) for j, c in enumerate(context): noise_pred[:, :, c] = noise_pred[:, :, c] + noise_pred_c counter[:, :, c] = counter[:, :, c] + 1 noise_pred = noise_pred / counter if ( last_mid_video_noises is not None and len(last_mid_video_noises) > 0 and i <= num_inference_steps // 2 # 是个超参数 super paramter ): if self.print_idx == 1: logger.debug( f"{i}, last_mid_video_noises={last_mid_video_noises[i].shape}" ) noise_pred = fuse_part_tensor( last_mid_video_noises[i + 1], noise_pred, video_overlap, weight=0.01, skip_step=1, ) if record_mid_video_noises: mid_video_noises.append(noise_pred[:, :, -video_overlap:]) # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale_lst[i] * ( noise_pred_text - noise_pred_uncond ) if self.print_idx == 0: logger.debug( f"before step, noise_pred={noise_pred.shape}, {noise_pred.device}, latents={latents.shape}, {latents.device}, t={t}" ) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, **extra_step_kwargs, ).prev_sample if ( last_mid_video_latents is not None and len(last_mid_video_latents) > 0 and i <= 1 # 超参数, super parameter ): if self.print_idx == 1: logger.debug( f"{i}, last_mid_video_latents={last_mid_video_latents[i].shape}" ) latents = fuse_part_tensor( last_mid_video_latents[i + 1], latents, video_overlap, weight=0.1, skip_step=0, ) if record_mid_video_latents: mid_video_latents.append(latents[:, :, -video_overlap:]) if need_middle_latents is True: videos_mid.append(self.decode_latents(latents)) # call the callback, if provided if i == len(timesteps) - 1 or ( (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 ): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) self.print_idx += 1 if condition_latents is not None: latents = batch_concat_two_tensor_with_index( data1=condition_latents, data1_index=vision_condition_latent_index, data2=latents, data2_index=latent_index, dim=2, ) video = self.decode_latents(latents) if skip_temporal_layer: self.unet.set_skip_temporal_layers(False) if need_hist_match: video[:, :, latent_index, :, :] = self.hist_match_with_vis_cond( batch_index_select(video, index=latent_index, dim=2), batch_index_select(video, index=vision_condition_latent_index, dim=2), ) # Convert to tensor if output_type == "tensor": videos_mid = [torch.from_numpy(x) for x in videos_mid] video = torch.from_numpy(video) else: latents = latents.cpu().numpy() if not return_dict: return ( video, latents, videos_mid, mid_video_latents, mid_video_noises, ) return VideoPipelineOutput( videos=video, latents=latents, videos_mid=videos_mid, mid_video_latents=mid_video_latents, mid_video_noises=mid_video_noises, )