Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import inspect | |
import math | |
import time | |
import warnings | |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
from dataclasses import dataclass | |
from einops import rearrange, repeat | |
import PIL.Image | |
import numpy as np | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer | |
from diffusers.pipelines.controlnet.pipeline_controlnet import ( | |
StableDiffusionSafetyChecker, | |
EXAMPLE_DOC_STRING, | |
) | |
from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import ( | |
StableDiffusionControlNetImg2ImgPipeline as DiffusersStableDiffusionControlNetImg2ImgPipeline, | |
) | |
from diffusers.configuration_utils import FrozenDict | |
from diffusers.models import AutoencoderKL, ControlNetModel | |
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel | |
from diffusers.pipelines.stable_diffusion.safety_checker import ( | |
StableDiffusionSafetyChecker, | |
) | |
from diffusers.schedulers import KarrasDiffusionSchedulers | |
from diffusers.utils import ( | |
deprecate, | |
logging, | |
BaseOutput, | |
replace_example_docstring, | |
) | |
from diffusers.utils.torch_utils import is_compiled_module | |
from diffusers.loaders import TextualInversionLoaderMixin | |
from diffusers.models.attention import ( | |
BasicTransformerBlock as DiffusersBasicTransformerBlock, | |
) | |
from mmcm.vision.process.correct_color import ( | |
hist_match_color_video_batch, | |
hist_match_video_bcthw, | |
) | |
from ..models.attention import BasicTransformerBlock | |
from ..models.unet_3d_condition import UNet3DConditionModel | |
from ..utils.noise_util import random_noise, video_fusion_noise | |
from ..data.data_util import ( | |
adaptive_instance_normalization, | |
align_repeat_tensor_single_dim, | |
batch_adain_conditioned_tensor, | |
batch_concat_two_tensor_with_index, | |
batch_index_select, | |
fuse_part_tensor, | |
) | |
from ..utils.text_emb_util import encode_weighted_prompt | |
from ..utils.tensor_util import his_match | |
from ..utils.timesteps_util import generate_parameters_with_timesteps | |
from .context import get_context_scheduler, prepare_global_context | |
# Module-level logger, created through the `logging` helper imported from diffusers.utils.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@dataclass
class VideoPipelineOutput(BaseOutput):
    """Output container of the video pipeline.

    Declared as a dataclass because diffusers' ``BaseOutput`` is designed to be
    subclassed as one: its ``__post_init__`` walks ``dataclasses.fields`` to
    populate the dict-style and attribute-style access. Without the decorator
    no ``__init__``/``__post_init__`` is generated for these fields.
    """

    videos: Union[torch.Tensor, np.ndarray]
    latents: Union[torch.Tensor, np.ndarray]
    videos_mid: Union[torch.Tensor, np.ndarray]
    # Optional intermediate results; ``None`` when not produced.
    down_block_res_samples: Optional[Tuple[torch.FloatTensor, ...]] = None
    mid_block_res_samples: Optional[torch.FloatTensor] = None
    up_block_res_samples: Optional[torch.FloatTensor] = None
    mid_video_latents: Optional[List[torch.FloatTensor]] = None
    mid_video_noises: Optional[List[torch.FloatTensor]] = None
def torch_dfs(model: torch.nn.Module):
    """List ``model`` and every sub-module beneath it in depth-first pre-order.

    A parent always precedes its children, and siblings keep the order given
    by ``Module.children()``. A sub-module reachable through several parents
    is listed once per path (no global de-duplication, unlike
    ``Module.modules()``).
    """
    visited = []
    stack = [model]
    while stack:
        module = stack.pop()
        visited.append(module)
        # Push children reversed so the first child is popped (visited) first.
        stack.extend(reversed(list(module.children())))
    return visited
def prepare_image(
    image,  # b c t h w
    batch_size,
    device,
    dtype,
    image_processor: Callable,
    num_images_per_prompt: int = 1,
    width=None,
    height=None,
):
    """Convert raw frames into a preprocessed 4-D tensor batch.

    Accepts a numpy array, a torch tensor, or a non-empty list of numpy arrays
    or tensors (concatenated along axis 0). A 5-D input laid out as
    ``b c t h w`` is flattened to ``(b t) c h w``. Frames are resized
    (bilinear) so height and width are multiples of
    ``image_processor.config.vae_scale_factor``. Pixel values are scaled by
    1/255, and ``image_processor.normalize`` is applied when
    ``image_processor.config.do_normalize`` is set and no negative values are
    present.

    Args:
        image: frames to preprocess; pixel values are expected in [0, 255].
        batch_size: repeat factor used when exactly one frame is given.
        device: device of the returned tensor.
        dtype: dtype of the returned tensor.
        image_processor: object exposing ``config.vae_scale_factor``,
            ``config.do_normalize`` and ``normalize``.
        num_images_per_prompt: repeat factor used when several frames are given.
        width: target width before rounding; defaults to the input width.
        height: target height before rounding; defaults to the input height.

    Returns:
        torch.Tensor: preprocessed frames on ``device`` with ``dtype``.

    Raises:
        ValueError: if ``image`` is an empty list.
        NotImplementedError: if ``image`` is a list of strings (paths).
    """
    if isinstance(image, list):
        if not image:
            raise ValueError("`image` must not be an empty list")
        if isinstance(image[0], str):
            raise NotImplementedError
        if isinstance(image[0], np.ndarray):
            image = np.concatenate(image, axis=0)
        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, dim=0)
    if isinstance(image, np.ndarray):
        image = torch.from_numpy(image)
    if image.ndim == 5:
        image = rearrange(image, "b c t h w-> (b t) c h w")
    if height is None:
        height = image.shape[-2]
    if width is None:
        width = image.shape[-1]
    # Read from the config: direct attribute access on a diffusers ConfigMixin
    # is deprecated.
    vae_scale_factor = image_processor.config.vae_scale_factor
    width, height = (x - x % vae_scale_factor for x in (width, height))
    # Convert to float *before* resizing: bilinear interpolation is not
    # implemented for integer (e.g. uint8) tensors on every backend, and
    # interpolating in float avoids rounding intermediate values.
    image = image.to(dtype=torch.float32)
    if height != image.shape[-2] or width != image.shape[-1]:
        image = torch.nn.functional.interpolate(
            image, size=(height, width), mode="bilinear"
        )
    image = image / 255.0
    do_normalize = image_processor.config.do_normalize
    if image.min() < 0:
        warnings.warn(
            "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
            f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
            FutureWarning,
        )
        do_normalize = False
    if do_normalize:
        image = image_processor.normalize(image)
    # A single frame is broadcast to the whole batch; otherwise the batch is
    # assumed to already match the prompt batch size.
    repeat_by = batch_size if image.shape[0] == 1 else num_images_per_prompt
    image = image.repeat_interleave(repeat_by, dim=0)
    return image.to(device=device, dtype=dtype)
class MusevControlNetPipeline( | |
DiffusersStableDiffusionControlNetImg2ImgPipeline, TextualInversionLoaderMixin | |
): | |
""" | |
a union diffusers pipeline, support | |
1. text2image model only, or text2video model, by setting skip_temporal_layer | |
2. text2video, image2video, video2video; | |
3. multi controlnet | |
4. IPAdapter | |
5. referencenet | |
6. IPAdapterFaceID | |
""" | |
_optional_components = [ | |
"safety_checker", | |
"feature_extractor", | |
] | |
print_idx = 0 | |
def __init__(
    self,
    vae: AutoencoderKL,
    unet: UNet3DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    controlnet: ControlNetModel
    | List[ControlNetModel]
    | Tuple[ControlNetModel]
    | MultiControlNetModel,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = False,
    referencenet: Optional[nn.Module] = None,
    vision_clip_extractor: Optional[nn.Module] = None,
    ip_adapter_image_proj: Optional[nn.Module] = None,
    face_emb_extractor: Optional[nn.Module] = None,
    facein_image_proj: Optional[nn.Module] = None,
    ip_adapter_face_emb_extractor: Optional[nn.Module] = None,
    ip_adapter_face_image_proj: Optional[nn.Module] = None,
    pose_guider: Optional[nn.Module] = None,
):
    """Build the pipeline from the base components plus optional extra modules.

    The standard components are handed to the parent diffusers ControlNet
    img2img pipeline. ``referencenet`` is stored unchanged. Every other
    optional extra module that is an ``nn.Module`` is moved to the UNet's dtype
    and device before being stored as an attribute of the same name.
    """
    # Pass the parent's parameters by keyword: their positional order is not
    # stable across diffusers releases (some insert an ``image_encoder``
    # parameter before ``requires_safety_checker``), so positional passing
    # can silently bind ``requires_safety_checker`` to the wrong parameter.
    super().__init__(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        requires_safety_checker=requires_safety_checker,
    )
    # The reference network is stored as-is (not moved to the UNet's dtype/device).
    self.referencenet = referencenet
    auxiliary_modules = {
        # ip_adapter
        "vision_clip_extractor": vision_clip_extractor,
        "ip_adapter_image_proj": ip_adapter_image_proj,
        # facein
        "face_emb_extractor": face_emb_extractor,
        "facein_image_proj": facein_image_proj,
        # ip_adapter_face
        "ip_adapter_face_emb_extractor": ip_adapter_face_emb_extractor,
        "ip_adapter_face_image_proj": ip_adapter_face_image_proj,
        "pose_guider": pose_guider,
    }
    for attr_name, module in auxiliary_modules.items():
        # Only real nn.Modules are moved; other values (e.g. None) are stored unchanged.
        if isinstance(module, nn.Module):
            module.to(dtype=self.unet.dtype, device=self.unet.device)
        setattr(self, attr_name, module)
def decode_latents(self, latents):
    """Decode 5-D video latents into a video using the parent's decoder.

    ``latents`` laid out as ``b c f h w`` is flattened into per-frame latents.
    These are decoded by the parent class's ``decode_latents``, which yields
    ``h w c`` frames, and the frames are regrouped into a ``b c f h w`` video.
    """
    n_videos = latents.shape[0]
    per_frame_latents = rearrange(latents, "b c f h w -> (b f) c h w")
    frames = super().decode_latents(latents=per_frame_latents)
    return rearrange(frames, "(b f) h w c -> b c f h w", b=n_videos)
def prepare_latents(
    self,
    batch_size: int,
    num_channels_latents: int,
    video_length: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: torch.Generator,
    latents: torch.Tensor = None,
    w_ind_noise: float = 0.5,
    image: torch.Tensor = None,
    timestep: int = None,
    initial_common_latent: torch.Tensor = None,
    noise_type: str = "random",
    add_latents_noise: bool = False,
    need_img_based_video_noise: bool = False,
    condition_latents: torch.Tensor = None,
    img_weight=1e-3,
) -> torch.Tensor:
    """Prepare the initial video latents for denoising.

    The target shape is ``(batch_size, num_channels_latents, video_length,
    height // self.vae_scale_factor, width // self.vae_scale_factor)``.

    Supported cases:

    * noise latents: ``image`` and ``latents`` are both ``None``. Noise is
      drawn and scaled by ``scheduler.init_noise_sigma`` (e.g. text2video).
    * image-based latents: ``image`` is given and ``latents`` is ``None``.
      The frames are VAE-encoded (posterior mean) and noised to ``timestep``
      with ``scheduler.add_noise`` (e.g. video2video).
    * given latents: ``latents`` is given. They are noised to ``timestep``
      when ``add_latents_noise`` is set; otherwise they are only scaled by
      ``scheduler.init_noise_sigma``.

    When ``need_img_based_video_noise`` is set, ``condition_latents`` is given,
    and both ``image`` and ``latents`` are ``None``, the noise is blended with
    the time-averaged condition latents as
    ``sqrt(img_weight) * cond + sqrt(1 - img_weight) * noise``.

    Args:
        batch_size: number of videos.
        num_channels_latents: latent channel count.
        video_length: number of frames per video.
        height: pixel height (divided by ``self.vae_scale_factor``).
        width: pixel width (divided by ``self.vae_scale_factor``).
        dtype: dtype of the returned latents.
        device: device of the returned latents.
        generator: generator(s) forwarded to the noise helpers. When ``image``
            is given, a list must have ``batch_size`` entries.
        latents: optional pre-computed latents (torch or numpy).
        w_ind_noise: forwarded to ``video_fusion_noise``.
        image: optional frames, ``b c t h w`` or ``(b t) c h w``. The frame
            count per video must equal ``video_length``.
        timestep: timestep used by ``scheduler.add_noise``.
        initial_common_latent: forwarded to ``video_fusion_noise``.
        noise_type: ``"random"`` or ``"video_fusion"``.
        add_latents_noise: noise the provided ``latents`` instead of scaling them.
        need_img_based_video_noise: blend the noise with ``condition_latents``.
        condition_latents: ``b c t h w`` latents used for the blending.
        img_weight: blending weight of the condition latents.

    Returns:
        torch.Tensor: latents with the target shape on ``device``/``dtype``.

    Raises:
        ValueError: unsupported ``noise_type`` when noise has to be drawn.
        ValueError: ``generator`` is a list whose length differs from ``batch_size``.
        ValueError: the encoded ``image`` batch cannot be duplicated to ``batch_size``.
        ValueError: the resulting latents do not have the target shape.
    """
    # ref https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py#L691
    # ref https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet.py#L659
    shape = (
        batch_size,
        num_channels_latents,
        video_length,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )
    if latents is None or add_latents_noise:
        if noise_type == "random":
            noise = random_noise(
                shape=shape, dtype=dtype, device=device, generator=generator
            )
        elif noise_type == "video_fusion":
            noise = video_fusion_noise(
                shape=shape,
                dtype=dtype,
                device=device,
                generator=generator,
                w_ind_noise=w_ind_noise,
                initial_common_noise=initial_common_latent,
            )
        else:
            # Fail early instead of hitting an unbound ``noise`` further down.
            raise ValueError(
                f"Unsupported noise_type={noise_type!r}, expected 'random' or 'video_fusion'"
            )
        if (
            need_img_based_video_noise
            and condition_latents is not None
            and image is None
            and latents is None
        ):
            if self.print_idx == 0:
                logger.debug(
                    (
                        f"need_img_based_video_noise, condition_latents={condition_latents.shape},"
                        f"batch_size={batch_size}, noise={noise.shape}, video_length={video_length}"
                    )
                )
            # Average the condition over time, then broadcast it to every frame.
            condition_latents = condition_latents.mean(dim=2, keepdim=True)
            condition_latents = repeat(
                condition_latents, "b c t h w->b c (t x) h w", x=video_length
            )
            noise = (
                img_weight**0.5 * condition_latents
                + (1 - img_weight) ** 0.5 * noise
            )
            if self.print_idx == 0:
                logger.debug(f"noise={noise.shape}")
    if image is not None:
        if image.ndim == 5:
            image = rearrange(image, "b c t h w->(b t) c h w")
        image = image.to(device=device, dtype=dtype)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if isinstance(generator, list):
            # The posterior mean is deterministic, so the generators are not
            # consumed; frames are just encoded one at a time. Every frame
            # ((b t) rows) is encoded, matching the single-call branch below.
            init_latents = torch.cat(
                [
                    self.vae.encode(image[i : i + 1]).latent_dist.mean
                    for i in range(image.shape[0])
                ],
                dim=0,
            )
        else:
            init_latents = self.vae.encode(image).latent_dist.mean
        init_latents = self.vae.config.scaling_factor * init_latents
        if (
            batch_size > init_latents.shape[0]
            and batch_size % init_latents.shape[0] == 0
        ):
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate(
                "len(prompt) != len(image)",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat(
                [init_latents] * additional_image_per_prompt, dim=0
            )
        elif batch_size > init_latents.shape[0]:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        # Resize when *either* spatial dimension differs from the latent size.
        if init_latents.shape[2] != shape[3] or init_latents.shape[3] != shape[4]:
            init_latents = torch.nn.functional.interpolate(
                init_latents,
                size=(shape[3], shape[4]),
                mode="bilinear",
            )
        init_latents = rearrange(
            init_latents, "(b t) c h w-> b c t h w", t=video_length
        )
        if self.print_idx == 0:
            logger.debug(f"init_latensts={init_latents.shape}")
    if latents is None:
        if image is None:
            latents = noise * self.scheduler.init_noise_sigma
        else:
            if self.print_idx == 0:
                logger.debug(f"prepare latents, image is not None")
            latents = self.scheduler.add_noise(init_latents, noise, timestep)
    else:
        if isinstance(latents, np.ndarray):
            latents = torch.from_numpy(latents)
        latents = latents.to(device=device, dtype=dtype)
        if add_latents_noise:
            latents = self.scheduler.add_noise(latents, noise, timestep)
        else:
            latents = latents * self.scheduler.init_noise_sigma
    if latents.shape != shape:
        raise ValueError(
            f"Unexpected latents shape, got {latents.shape}, expected {shape}"
        )
    latents = latents.to(device, dtype=dtype)
    return latents
def prepare_image(
    self,
    image,  # b c t h w
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    width=None,
    height=None,
):
    """Preprocess frames with the pipeline's main ``self.image_processor``.

    Thin wrapper: delegates to the module-level ``prepare_image`` helper (the
    bare name resolves to the module function, not to this method).
    """
    processor = self.image_processor
    return prepare_image(
        image,
        batch_size,
        device,
        dtype,
        processor,
        num_images_per_prompt=num_images_per_prompt,
        width=width,
        height=height,
    )
def prepare_control_image(
    self,
    image,  # b c t h w
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """Preprocess ControlNet conditioning frames.

    Uses the module-level ``prepare_image`` with ``self.control_image_processor``.
    When classifier-free guidance is enabled and ``guess_mode`` is off, the
    result is stacked twice along dim 0, presumably to match a CFG batch of
    unconditional + conditional inputs.
    """
    control = prepare_image(
        image=image,
        batch_size=batch_size,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        dtype=dtype,
        width=width,
        height=height,
        image_processor=self.control_image_processor,
    )
    duplicate_for_cfg = do_classifier_free_guidance and not guess_mode
    return torch.cat([control, control]) if duplicate_for_cfg else control
def check_inputs(
    self,
    prompt,
    image,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    controlnet_conditioning_scale=1,
    control_guidance_start=0,
    control_guidance_end=1,
):
    """Validate call arguments via the parent pipeline when a control image is given.

    Without a control ``image`` no validation is performed yet.

    Returns:
        Whatever the parent ``check_inputs`` returns, or ``None``.
    """
    if image is None:
        # TODO: implement validation for calls without a control image.
        return None
    # Optional arguments are passed by keyword: the parent's parameter order
    # has changed between diffusers releases (extra parameters were inserted),
    # and positional passing would bind values to the wrong parameters.
    return super().check_inputs(
        prompt,
        image,
        callback_steps,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_start=control_guidance_start,
        control_guidance_end=control_guidance_end,
    )
def hist_match_with_vis_cond(
    self, video: np.ndarray, target: np.ndarray
) -> np.ndarray:
    """Histogram-match ``video`` against the visual-condition ``target``.

    Args:
        video: array laid out as ``b c t1 h w``.
        target: array laid out as ``b c t2 h w`` with ``t2 == 1``.

    Returns:
        np.ndarray: result of ``hist_match_video_bcthw`` called with
        ``value=255.0`` (presumably the pixel value range).
    """
    matched = hist_match_video_bcthw(video, target, value=255.0)
    return matched
def get_facein_image_emb(
    self, refer_face_image, device, dtype, batch_size, do_classifier_free_guidance
):
    """Build FaceIn identity embeddings from reference face images.

    Returns ``None`` unless ``self.face_emb_extractor``,
    ``self.facein_image_proj`` and ``refer_face_image`` are all set.

    Otherwise:

    * ``refer_face_image`` (``b c t h w``, numpy or torch) is flattened to
      ``(b t) h w c`` frames and passed to ``face_emb_extractor.extract_images``.
    * The resulting ``(b t) d`` or ``(b t) h w d`` embedding is brought to a
      ``(b t) n d`` token layout and projected by ``facein_image_proj``.
    * The projection is regrouped to ``b (t n) q`` and repeated along dim 0
      to ``batch_size``.

    With classifier-free guidance, the projection of an all-zero embedding
    (same layout) is concatenated in front along dim 0.
    """
    # refer_face_image and its face_emb
    if self.print_idx == 0:
        logger.debug(
            f"face_emb_extractor={type(self.face_emb_extractor)}, facein_image_proj={type(self.facein_image_proj)}, refer_face_image={type(refer_face_image)}, "
        )
    if (
        self.face_emb_extractor is not None
        and self.facein_image_proj is not None
        and refer_face_image is not None
    ):
        if self.print_idx == 0:
            logger.debug(f"refer_face_image={refer_face_image.shape}")
        if isinstance(refer_face_image, np.ndarray):
            refer_face_image = torch.from_numpy(refer_face_image)
        n_refer_face_image = refer_face_image.shape[2]
        refer_face_image_facein = rearrange(
            refer_face_image, "b c t h w-> (b t) h w c"
        )
        # extract_images returns a pair; the second item (aligned face images)
        # is unused. Embedding layout: (b t) d or (b t) h w d.
        refer_face_image_emb, _ = self.face_emb_extractor.extract_images(
            refer_face_image_facein, return_type="torch"
        )
        refer_face_image_emb = refer_face_image_emb.to(device=device, dtype=dtype)
        if self.print_idx == 0:
            logger.debug(f"refer_face_image_emb={refer_face_image_emb.shape}")
        # Bring the embedding to a (b t) n d token layout. The rank must be
        # checked via ``ndim``: ``tensor.shape == 2`` compares a torch.Size
        # with an int and is never true.
        if refer_face_image_emb.ndim == 2:
            refer_face_image_emb = rearrange(refer_face_image_emb, "bt d-> bt 1 d")
        elif refer_face_image_emb.ndim == 4:
            refer_face_image_emb = rearrange(
                refer_face_image_emb, "bt h w d-> bt (h w) d"
            )
        refer_face_image_emb_bk = refer_face_image_emb
        refer_face_image_emb = self.facein_image_proj(refer_face_image_emb)
        # TODO: the output of IPAdapterPlus's vision_clip is not supported yet.
        refer_face_image_emb = rearrange(
            refer_face_image_emb,
            "(b t) n q-> b (t n) q",
            t=n_refer_face_image,
        )
        refer_face_image_emb = align_repeat_tensor_single_dim(
            refer_face_image_emb, target_length=batch_size, dim=0
        )
        if do_classifier_free_guidance:
            # TODO: the unconditional feature is fixed; there is room for optimization.
            uncond_refer_face_image_emb = self.facein_image_proj(
                torch.zeros_like(refer_face_image_emb_bk).to(
                    device=device, dtype=dtype
                )
            )
            # TODO: the output of IPAdapterPlus's vision_clip may not be supported.
            uncond_refer_face_image_emb = rearrange(
                uncond_refer_face_image_emb,
                "(b t) n q-> b (t n) q",
                t=n_refer_face_image,
            )
            uncond_refer_face_image_emb = align_repeat_tensor_single_dim(
                uncond_refer_face_image_emb, target_length=batch_size, dim=0
            )
            if self.print_idx == 0:
                logger.debug(
                    f"uncond_refer_face_image_emb, {uncond_refer_face_image_emb.shape}"
                )
                logger.debug(f"refer_face_image_emb, {refer_face_image_emb.shape}")
            refer_face_image_emb = torch.concat(
                [
                    uncond_refer_face_image_emb,
                    refer_face_image_emb,
                ],
            )
    else:
        refer_face_image_emb = None
    if self.print_idx == 0:
        logger.debug(f"refer_face_image_emb={type(refer_face_image_emb)}")
    return refer_face_image_emb
def get_ip_adapter_face_emb(
    self, refer_face_image, device, dtype, batch_size, do_classifier_free_guidance
):
    """Build IP-Adapter face embeddings from reference face images.

    Returns ``None`` unless ``self.ip_adapter_face_emb_extractor``,
    ``self.ip_adapter_face_image_proj`` and ``refer_face_image`` are all set.

    Otherwise:

    * ``refer_face_image`` (``b c t h w``, numpy or torch) is flattened to
      ``(b t) h w c`` frames and passed to
      ``ip_adapter_face_emb_extractor.extract_images``.
    * The ``(b t) d`` or ``(b t) h w d`` embedding is brought to a
      ``(b t) n d`` token layout and projected by
      ``ip_adapter_face_image_proj``.
    * The projection is regrouped to ``b (t n) q`` and repeated along dim 0
      to ``batch_size``.

    With classifier-free guidance, the projection of an all-zero embedding
    (same layout) is concatenated in front along dim 0.
    """
    # refer_face_image and its ip_adapter_face_emb
    if self.print_idx == 0:
        logger.debug(
            f"ip_adapter_face_emb_extractor={type(self.ip_adapter_face_emb_extractor)}, ip_adapter_face_image_proj={type(self.ip_adapter_face_image_proj)}, refer_face_image={type(refer_face_image)}, "
        )
    if (
        self.ip_adapter_face_emb_extractor is not None
        and self.ip_adapter_face_image_proj is not None
        and refer_face_image is not None
    ):
        if self.print_idx == 0:
            logger.debug(f"refer_face_image={refer_face_image.shape}")
        if isinstance(refer_face_image, np.ndarray):
            refer_face_image = torch.from_numpy(refer_face_image)
        n_refer_face_image = refer_face_image.shape[2]
        refer_ip_adapter_face_image = rearrange(
            refer_face_image, "b c t h w-> (b t) h w c"
        )
        # extract_images returns a pair; the second item (aligned face images)
        # is unused. Embedding layout: (b t) d or (b t) h w d.
        refer_face_image_emb, _ = self.ip_adapter_face_emb_extractor.extract_images(
            refer_ip_adapter_face_image, return_type="torch"
        )
        refer_face_image_emb = refer_face_image_emb.to(device=device, dtype=dtype)
        if self.print_idx == 0:
            logger.debug(f"refer_face_image_emb={refer_face_image_emb.shape}")
        # Bring the embedding to a (b t) n d token layout. The rank must be
        # checked via ``ndim``: ``tensor.shape == 2`` compares a torch.Size
        # with an int and is never true.
        if refer_face_image_emb.ndim == 2:
            refer_face_image_emb = rearrange(refer_face_image_emb, "bt d-> bt 1 d")
        elif refer_face_image_emb.ndim == 4:
            refer_face_image_emb = rearrange(
                refer_face_image_emb, "bt h w d-> bt (h w) d"
            )
        refer_face_image_emb_bk = refer_face_image_emb
        refer_face_image_emb = self.ip_adapter_face_image_proj(refer_face_image_emb)
        refer_face_image_emb = rearrange(
            refer_face_image_emb,
            "(b t) n q-> b (t n) q",
            t=n_refer_face_image,
        )
        refer_face_image_emb = align_repeat_tensor_single_dim(
            refer_face_image_emb, target_length=batch_size, dim=0
        )
        if do_classifier_free_guidance:
            # TODO: the unconditional feature is fixed; there is room for optimization.
            uncond_refer_face_image_emb = self.ip_adapter_face_image_proj(
                torch.zeros_like(refer_face_image_emb_bk).to(
                    device=device, dtype=dtype
                )
            )
            # TODO: the output of IPAdapterPlus's vision_clip may not be supported.
            uncond_refer_face_image_emb = rearrange(
                uncond_refer_face_image_emb,
                "(b t) n q-> b (t n) q",
                t=n_refer_face_image,
            )
            uncond_refer_face_image_emb = align_repeat_tensor_single_dim(
                uncond_refer_face_image_emb, target_length=batch_size, dim=0
            )
            if self.print_idx == 0:
                logger.debug(
                    f"uncond_refer_face_image_emb, {uncond_refer_face_image_emb.shape}"
                )
                logger.debug(f"refer_face_image_emb, {refer_face_image_emb.shape}")
            refer_face_image_emb = torch.concat(
                [
                    uncond_refer_face_image_emb,
                    refer_face_image_emb,
                ],
            )
    else:
        refer_face_image_emb = None
    if self.print_idx == 0:
        logger.debug(f"ip_adapter_face_emb={type(refer_face_image_emb)}")
    return refer_face_image_emb
def get_ip_adapter_image_emb(
    self,
    ip_adapter_image,
    device,
    dtype,
    batch_size,
    do_classifier_free_guidance,
    height,
    width,
):
    """Build IP-Adapter image-prompt embeddings from reference images.

    Returns ``None`` unless ``self.vision_clip_extractor`` and
    ``ip_adapter_image`` are both set.

    Otherwise:

    * ``ip_adapter_image`` (``b c t h w``, numpy or torch) is flattened to
      ``(b t) h w c`` frames and encoded by
      ``vision_clip_extractor.extract_images``. A 2-D result gets a token axis.
    * If ``self.ip_adapter_image_proj`` is set, the embedding is projected
      with it. Otherwise the raw vision embedding is used, e.g. when the image
      prompt replaces the text prompt.
    * The result is regrouped to ``b (t n) q`` and repeated along dim 0 to
      ``batch_size``.

    With classifier-free guidance an unconditional embedding is prepended
    along dim 0. It is the projection of an all-zero embedding when a
    projector exists, and all zeros otherwise.
    """
    # refer_image vision_clip and its ipadapter_emb
    if self.print_idx == 0:
        logger.debug(
            f"vision_clip_extractor={type(self.vision_clip_extractor)},"
            f"ip_adapter_image_proj={type(self.ip_adapter_image_proj)},"
            f"ip_adapter_image={type(ip_adapter_image)},"
        )
    if self.vision_clip_extractor is not None and ip_adapter_image is not None:
        if self.print_idx == 0:
            logger.debug(f"ip_adapter_image={ip_adapter_image.shape}")
        if isinstance(ip_adapter_image, np.ndarray):
            ip_adapter_image = torch.from_numpy(ip_adapter_image)
        n_ip_adapter_image = ip_adapter_image.shape[2]
        ip_adapter_image = rearrange(ip_adapter_image, "b c t h w-> (b t) h w c")
        ip_adapter_image_emb = self.vision_clip_extractor.extract_images(
            ip_adapter_image,
            target_height=height,
            target_width=width,
            return_type="torch",
        )
        if ip_adapter_image_emb.ndim == 2:
            ip_adapter_image_emb = rearrange(ip_adapter_image_emb, "b q-> b 1 q")
        # Per-frame ((b t) n q) embedding before projection/regrouping; the
        # unconditional branch is built from it so both branches share a layout.
        ip_adapter_image_emb_bk = ip_adapter_image_emb
        if self.ip_adapter_image_proj is not None:
            logger.debug(f"ip_adapter_image_proj is not None, apply it")
            ip_adapter_image_emb = self.ip_adapter_image_proj(ip_adapter_image_emb)
        # TODO: the output of IPAdapterPlus's vision_clip is not supported yet.
        ip_adapter_image_emb = rearrange(
            ip_adapter_image_emb,
            "(b t) n q-> b (t n) q",
            t=n_ip_adapter_image,
        )
        ip_adapter_image_emb = align_repeat_tensor_single_dim(
            ip_adapter_image_emb, target_length=batch_size, dim=0
        )
        if do_classifier_free_guidance:
            # TODO: the unconditional feature is fixed; there is room for optimization.
            if self.ip_adapter_image_proj is not None:
                uncond_ip_adapter_image_emb = self.ip_adapter_image_proj(
                    torch.zeros_like(ip_adapter_image_emb_bk).to(
                        device=device, dtype=dtype
                    )
                )
                if self.print_idx == 0:
                    logger.debug(
                        f"uncond_ip_adapter_image_emb use ip_adapter_image_proj(zero_like)"
                    )
            else:
                # Zeros in the per-frame (b t) n q layout, *not* zeros_like the
                # already regrouped and batch-aligned embedding, so the
                # rearrange below is valid for any number of reference frames.
                uncond_ip_adapter_image_emb = torch.zeros_like(ip_adapter_image_emb_bk)
                if self.print_idx == 0:
                    logger.debug(f"uncond_ip_adapter_image_emb use zero_like")
            # TODO: the output of IPAdapterPlus's vision_clip may not be supported.
            uncond_ip_adapter_image_emb = rearrange(
                uncond_ip_adapter_image_emb,
                "(b t) n q-> b (t n) q",
                t=n_ip_adapter_image,
            )
            uncond_ip_adapter_image_emb = align_repeat_tensor_single_dim(
                uncond_ip_adapter_image_emb, target_length=batch_size, dim=0
            )
            if self.print_idx == 0:
                logger.debug(
                    f"uncond_ip_adapter_image_emb, {uncond_ip_adapter_image_emb.shape}"
                )
                logger.debug(f"ip_adapter_image_emb, {ip_adapter_image_emb.shape}")
            ip_adapter_image_emb = torch.concat(
                [
                    uncond_ip_adapter_image_emb,
                    ip_adapter_image_emb,
                ],
            )
    else:
        ip_adapter_image_emb = None
    if self.print_idx == 0:
        logger.debug(f"ip_adapter_image_emb={type(ip_adapter_image_emb)}")
    return ip_adapter_image_emb
def get_referencenet_image_vae_emb(
    self,
    refer_image,
    batch_size,
    num_videos_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance,
    width: int = None,
    height: int = None,
):
    """VAE-encode the reference image(s) for the reference network.

    Returns ``None`` unless both ``self.referencenet`` and ``refer_image`` are
    set. Otherwise ``refer_image`` (``b c t h w``) is preprocessed with
    ``self.prepare_image``, encoded to the VAE posterior mean, and scaled by
    ``vae.config.scaling_factor``, giving a ``(b t) c h w`` tensor.

    Under classifier-free guidance the embedding is duplicated along the batch
    axis. The unconditional half reuses the conditional embedding; encoding a
    zero image or using zero latents were alternatives considered upstream.
    """
    # prepare_referencenet_emb
    if self.print_idx == 0:
        logger.debug(
            f"referencenet={type(self.referencenet)}, refer_image={type(refer_image)}"
        )
    if self.referencenet is None or refer_image is None:
        return None
    n_refer_image = refer_image.shape[2]
    refer_image_vae = self.prepare_image(
        refer_image,
        batch_size=batch_size * num_videos_per_prompt,
        num_images_per_prompt=num_videos_per_prompt,
        device=device,
        dtype=dtype,
        width=width,
        height=height,
    )
    posterior_mean = self.vae.encode(refer_image_vae).latent_dist.mean
    refer_image_vae_emb = self.vae.config.scaling_factor * posterior_mean
    logger.debug(f"refer_image_vae_emb={refer_image_vae_emb.shape}")
    if do_classifier_free_guidance:
        per_video = rearrange(
            refer_image_vae_emb, "(b t) c h w-> b c t h w", t=n_refer_image
        )
        # unconditional half == conditional half, stacked along the batch axis
        per_video = torch.concat([per_video, per_video], dim=0)
        refer_image_vae_emb = rearrange(per_video, "b c t h w-> (b t) c h w")
        logger.debug(f"refer_image_vae_emb={refer_image_vae_emb.shape}")
    return refer_image_vae_emb
def get_referencenet_emb(
    self,
    refer_image_vae_emb,
    refer_image,
    batch_size,
    num_videos_per_prompt,
    device,
    dtype,
    ip_adapter_image_emb,
    do_classifier_free_guidance,
    prompt_embeds,
    ref_timestep_int: int = 0,
):
    """Run the reference network on VAE-encoded reference images.

    Returns ``(None, None, None)`` unless ``self.referencenet``,
    ``refer_image_vae_emb`` and ``refer_image`` are all set. Otherwise the
    reference network is called on ``refer_image_vae_emb`` at timestep zero.
    ``ip_adapter_image_emb`` is used as encoder hidden states when available,
    and ``prompt_embeds`` otherwise.

    Args:
        refer_image_vae_emb: reference latents, presumably the output of
            ``get_referencenet_image_vae_emb`` (``(b t) c h w``) — confirm.
        refer_image: reference images (``b c t h w``); only ``shape[2]``, the
            number of reference frames, is used.
        batch_size: unused; kept for interface compatibility.
        num_videos_per_prompt: unused; kept for interface compatibility.
        device: device of the timestep tensor built for an int ``ref_timestep_int``.
        dtype: unused; kept for interface compatibility.
        ip_adapter_image_emb: preferred encoder hidden states.
        do_classifier_free_guidance: unused; kept for interface compatibility.
        prompt_embeds: fallback encoder hidden states.
        ref_timestep_int: tensor or int. The reference network always runs at
            timestep 0; for a tensor only its shape, dtype and device are reused.

    Returns:
        tuple: ``(down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb)``.
    """
    # prepare_referencenet_emb
    if self.print_idx == 0:
        logger.debug(
            f"referencenet={type(self.referencenet)}, refer_image={type(refer_image)}"
        )
    if (
        self.referencenet is not None
        and refer_image_vae_emb is not None
        and refer_image is not None
    ):
        n_refer_image = refer_image.shape[2]
        # Timestep is always zero. ``torch.zeros_like`` only accepts tensors,
        # so a plain int (including the default ``0``) gets a 0-d timestep
        # tensor on ``device`` instead of raising a TypeError.
        if torch.is_tensor(ref_timestep_int):
            ref_timestep = torch.zeros_like(ref_timestep_int)
        else:
            ref_timestep = torch.zeros((), dtype=torch.long, device=device)
        # The reference network prefers the clip-vision embedding extracted
        # for the IP-Adapter over the text prompt embedding.
        if ip_adapter_image_emb is not None:
            refer_prompt_embeds = ip_adapter_image_emb
        else:
            refer_prompt_embeds = prompt_embeds
        if self.print_idx == 0:
            logger.debug(
                f"use referencenet: n_refer_image={n_refer_image}, refer_image_vae_emb={refer_image_vae_emb.shape}, ref_timestep={ref_timestep.shape}"
            )
            if prompt_embeds is not None:
                logger.debug(f"prompt_embeds={prompt_embeds.shape},")
        referencenet_params = {
            "sample": refer_image_vae_emb,
            "encoder_hidden_states": refer_prompt_embeds,
            "timestep": ref_timestep,
            "num_frames": n_refer_image,
            "return_ndim": 5,
        }
        (
            down_block_refer_embs,
            mid_block_refer_emb,
            refer_self_attn_emb,
        ) = self.referencenet(**referencenet_params)
        # NOTE(review): no CFG-specific handling is applied to these outputs.
        # Duplicating them or pairing them with zeros were alternatives
        # considered upstream. Presumably refer_image_vae_emb already carries
        # the CFG batch — confirm with the caller.
    else:
        down_block_refer_embs = None
        mid_block_refer_emb = None
        refer_self_attn_emb = None
    if self.print_idx == 0:
        logger.debug(f"down_block_refer_embs={type(down_block_refer_embs)}")
        logger.debug(f"mid_block_refer_emb={type(mid_block_refer_emb)}")
        logger.debug(f"refer_self_attn_emb={type(refer_self_attn_emb)}")
    return down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb
def prepare_condition_latents_and_index(
    self,
    condition_images,
    condition_latents,
    video_length,
    batch_size,
    dtype,
    device,
    latent_index,
    vision_condition_latent_index,
):
    """Prepare vision-condition latents and the frame positions used to merge them.

    When condition latents exist, the full latent sequence is treated as
    ``n_cond + video_length`` frames long, where ``n_cond`` is
    ``condition_latents.shape[2]``. Some positions hold condition latents and
    the remaining positions hold the latents being generated.

    Args:
        condition_images: Optional pixel-space condition frames. Used only when
            ``condition_latents`` is None. They are VAE-encoded (posterior mean,
            scaled by ``self.vae.config.scaling_factor``) and rearranged from
            ``(b t) c h w`` to ``b c t h w``.
        condition_latents: Optional condition latents (``torch.Tensor`` or
            ``np.ndarray``). The time axis is dim 2.
        video_length: Number of frames to generate.
        batch_size: Batch size used to unflatten the encoded condition images.
        dtype: Target dtype for ``condition_latents``.
        device: Target device for every returned tensor.
        latent_index: Optional frame positions of the generated latents. When
            None and condition latents exist, it becomes every position not
            taken by ``vision_condition_latent_index``.
        vision_condition_latent_index: Optional frame positions of the condition
            latents. Negative entries count from the end of the full sequence,
            so ``-1`` is the last frame. When None and condition latents exist,
            the condition frames occupy positions ``0 .. n_cond - 1``.

    Returns:
        ``(condition_latents, latent_index, vision_condition_latent_index)``.
        Each index is a ``torch.long`` tensor on ``device``, or None.
    """
    # Encode pixel-space condition frames when no latents were supplied.
    if condition_images is not None and condition_latents is None:
        # Use the posterior mean (not a sample) so the encoding is deterministic.
        condition_latents = self.vae.encode(condition_images).latent_dist.mean
        condition_latents = self.vae.config.scaling_factor * condition_latents
        condition_latents = rearrange(
            condition_latents, "(b t) c h w-> b c t h w", b=batch_size
        )
        if self.print_idx == 0:
            logger.debug(
                f"condition_latents from condition_images, shape is condition_latents={condition_latents.shape}",
            )

    if condition_latents is not None:
        if isinstance(condition_latents, np.ndarray):
            condition_latents = torch.from_numpy(condition_latents)
        condition_latents = condition_latents.to(dtype=dtype, device=device)
        n_cond_frames = condition_latents.shape[2]
        total_frames = n_cond_frames + video_length

        if vision_condition_latent_index is not None:
            # Resolve negative positions against the full sequence, Python-style
            # (-1 -> total_frames - 1). Previously only -1 was handled; other
            # negative values leaked through and were never excluded from
            # latent_index.
            vision_condition_latent_index_lst = []
            for i_v in vision_condition_latent_index:
                i_v = int(i_v)
                vision_condition_latent_index_lst.append(
                    i_v if i_v >= 0 else total_frames + i_v
                )
        else:
            # Default: condition frames occupy the head of the sequence.
            vision_condition_latent_index_lst = list(range(n_cond_frames))
        vision_condition_latent_index = torch.tensor(
            vision_condition_latent_index_lst, dtype=torch.long, device=device
        )
        if self.print_idx == 0:
            logger.debug(
                f"vision_condition_latent_index {type(vision_condition_latent_index)}, {vision_condition_latent_index}"
            )

        if latent_index is None:
            # Generated frames fill every position not taken by a condition frame.
            taken = set(vision_condition_latent_index_lst)
            latent_index = torch.tensor(
                [idx for idx in range(total_frames) if idx not in taken],
                dtype=torch.long,
                device=device,
            )

    # Caller-supplied indices may be plain sequences (e.g. when no condition
    # latents are given). Normalise them to long tensors so `.to`, `.shape`
    # and tensor indexing work downstream.
    if vision_condition_latent_index is not None:
        if not torch.is_tensor(vision_condition_latent_index):
            vision_condition_latent_index = torch.as_tensor(
                vision_condition_latent_index, dtype=torch.long
            )
        vision_condition_latent_index = vision_condition_latent_index.to(
            device=device
        )
        if self.print_idx == 0:
            logger.debug(
                f"pipeline vision_condition_latent_index ={vision_condition_latent_index.shape}, {vision_condition_latent_index}"
            )
    if latent_index is not None:
        if not torch.is_tensor(latent_index):
            latent_index = torch.as_tensor(latent_index, dtype=torch.long)
        latent_index = latent_index.to(device=device)
        if self.print_idx == 0:
            logger.debug(
                f"pipeline latent_index ={latent_index.shape}, {latent_index}"
            )
    if self.print_idx == 0:
        logger.debug(f"condition_latents={type(condition_latents)}")
        logger.debug(f"latent_index={type(latent_index)}")
        logger.debug(
            f"vision_condition_latent_index={type(vision_condition_latent_index)}"
        )
    return condition_latents, latent_index, vision_condition_latent_index
def prepare_controlnet_and_guidance_parameter(
    self, control_guidance_start, control_guidance_end
):
    """Unwrap the controlnet and align the guidance window bounds into lists.

    If exactly one of ``control_guidance_start`` / ``control_guidance_end`` is
    a list, the scalar one is repeated to the list's length. If neither is a
    list, both are repeated once per net: ``len(controlnet.nets)`` for a
    ``MultiControlNetModel``, otherwise once. If both are lists, they are
    returned unchanged.

    Returns:
        ``(controlnet, control_guidance_start, control_guidance_end)``. The
        controlnet is the uncompiled module (``_orig_mod``) when
        ``self.controlnet`` is a compiled module.
    """
    wrapped = self.controlnet
    if is_compiled_module(self.controlnet):
        # A compiled module keeps the original under `_orig_mod`.
        wrapped = self.controlnet._orig_mod
    controlnet = wrapped

    start_is_list = isinstance(control_guidance_start, list)
    end_is_list = isinstance(control_guidance_end, list)

    if end_is_list and not start_is_list:
        control_guidance_start = [control_guidance_start] * len(control_guidance_end)
    elif start_is_list and not end_is_list:
        control_guidance_end = [control_guidance_end] * len(control_guidance_start)
    elif not start_is_list and not end_is_list:
        if isinstance(controlnet, MultiControlNetModel):
            n_nets = len(controlnet.nets)
        else:
            n_nets = 1
        control_guidance_start = [control_guidance_start] * n_nets
        control_guidance_end = [control_guidance_end] * n_nets

    return controlnet, control_guidance_start, control_guidance_end
def prepare_controlnet_guess_mode(self, controlnet, guess_mode):
    """Return the effective guess-mode flag.

    Guess mode is on when the caller requests it or when the controlnet config
    sets ``global_pool_conditions``. For anything other than a single
    ``ControlNetModel``, only the config of ``controlnet.nets[0]`` is used.
    """
    if isinstance(controlnet, ControlNetModel):
        source_config = controlnet.config
    else:
        source_config = controlnet.nets[0].config
    # Read the config flag unconditionally, before combining with guess_mode.
    pooled_conditions = source_config.global_pool_conditions
    return guess_mode or pooled_conditions
def prepare_controlnet_image_and_latents(
    self,
    controlnet,
    width,
    height,
    batch_size,
    num_videos_per_prompt,
    device,
    dtype,
    controlnet_latents=None,
    controlnet_condition_latents=None,
    control_image=None,
    controlnet_condition_images=None,
    guess_mode=False,
    do_classifier_free_guidance=False,
):
    """Prepare the controlnet conditioning input (pixel images or latents).

    For a single ``ControlNetModel``:
      * If ``controlnet_latents`` is given, any ``controlnet_condition_latents``
        are prepended along dim 2 (time). The batch is duplicated when
        classifier-free guidance is on and ``guess_mode`` is off. The result is
        flattened ``b c t h w -> (b t) c h w`` and moved to ``device``/``dtype``.
      * Otherwise ``controlnet_condition_images`` are prepended to
        ``control_image`` along dim 2, and the result goes through
        ``self.prepare_control_image`` with ``dtype=controlnet.dtype``.

    For a ``MultiControlNetModel``, ``control_image`` is iterated per net.
    ``controlnet_condition_images``, when it is a list, supplies one optional
    condition entry per net that is prepended the same way.

    Returns:
        ``(control_image, controlnet_latents)``. ``control_image`` is a list for
        a multi-controlnet.

    Raises:
        NotImplementedError: multi-controlnet with both ``controlnet_latents``
            and ``controlnet_condition_latents``.
        TypeError: ``controlnet`` is neither supported type.
    """

    def _prepend_frames(head, tail):
        # Concatenate `head` before `tail` on dim 2 (time). Stay in numpy when
        # both inputs are arrays; otherwise convert arrays and use torch.cat.
        if isinstance(head, np.ndarray) and isinstance(tail, np.ndarray):
            return np.concatenate([head, tail], axis=2)
        if isinstance(head, np.ndarray):
            head = torch.from_numpy(head)
        if isinstance(tail, np.ndarray):
            tail = torch.from_numpy(tail)
        return torch.cat([head, tail], dim=2)

    if isinstance(controlnet, ControlNetModel):
        if controlnet_latents is not None:
            if isinstance(controlnet_latents, np.ndarray):
                controlnet_latents = torch.from_numpy(controlnet_latents)
            if controlnet_condition_latents is not None:
                # TODO: concat with index instead of always prepending.
                controlnet_latents = _prepend_frames(
                    controlnet_condition_latents, controlnet_latents
                )
            if not guess_mode and do_classifier_free_guidance:
                # One copy each for the unconditional and conditional CFG halves.
                controlnet_latents = torch.cat([controlnet_latents] * 2, dim=0)
            controlnet_latents = rearrange(
                controlnet_latents, "b c t h w->(b t) c h w"
            )
            controlnet_latents = controlnet_latents.to(device=device, dtype=dtype)
            if self.print_idx == 0:
                logger.debug(
                    f"call, controlnet_latents.shape, {controlnet_latents.shape}"
                )
        else:
            # TODO: concat with index
            if isinstance(control_image, np.ndarray):
                control_image = torch.from_numpy(control_image)
            if controlnet_condition_images is not None:
                control_image = _prepend_frames(
                    controlnet_condition_images, control_image
                )
            control_image = self.prepare_control_image(
                image=control_image,
                width=width,
                height=height,
                batch_size=batch_size * num_videos_per_prompt,
                num_images_per_prompt=num_videos_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            if self.print_idx == 0:
                logger.debug(f"call, control_image.shape , {control_image.shape}")
    elif isinstance(controlnet, MultiControlNetModel):
        # TODO: directly support controlnet_latents instead of frames
        if (
            controlnet_latents is not None
            and controlnet_condition_latents is not None
        ):
            raise NotImplementedError(
                "MultiControlNetModel does not support controlnet_latents "
                "together with controlnet_condition_latents"
            )
        per_net_conditions = (
            controlnet_condition_images
            if isinstance(controlnet_condition_images, list)
            else None
        )
        control_images = []
        for i, control_image_ in enumerate(control_image):
            # Tensor condition images are prepended too, not only numpy arrays.
            if per_net_conditions is not None and per_net_conditions[i] is not None:
                control_image_ = _prepend_frames(per_net_conditions[i], control_image_)
            control_image_ = self.prepare_control_image(
                image=control_image_,
                width=width,
                height=height,
                batch_size=batch_size * num_videos_per_prompt,
                num_images_per_prompt=num_videos_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            control_images.append(control_image_)
        control_image = control_images
    else:
        # An explicit raise survives `python -O`, unlike `assert False`.
        raise TypeError(f"Unsupported controlnet type: {type(controlnet)}")

    if control_image is not None and self.print_idx == 0:
        if isinstance(control_image, list):
            logger.debug(f"control_image shape is {control_image[0].shape}")
        else:
            logger.debug(f"control_image shape is {control_image.shape}")
    return control_image, controlnet_latents
def get_controlnet_emb(
    self,
    run_controlnet,
    guess_mode,
    do_classifier_free_guidance,
    latents,
    prompt_embeds,
    latent_model_input,
    controlnet_keep,
    controlnet_conditioning_scale,
    control_image,
    controlnet_latents,
    i,
    t,
):
    """Run the controlnet for denoising step ``i`` and return its residuals.

    Returns ``(None, None)`` when ``run_controlnet`` is falsy or
    ``self.pose_guider`` is set. Otherwise it calls ``self.controlnet`` on the
    input flattened with ``b c t h w -> (b t) c h w`` and returns
    ``(down_block_res_samples, mid_block_res_sample)``.

    With guess mode plus classifier-free guidance, only the conditional half
    is fed to the controlnet. Its input is ``latents`` scaled by
    ``self.scheduler.scale_model_input`` with the second chunk of
    ``prompt_embeds``. Zero residuals are then prepended for the unconditional
    half. The conditioning scale is ``controlnet_conditioning_scale`` times
    ``controlnet_keep[i]``, element-wise when ``controlnet_keep[i]`` is a list.
    """
    if not run_controlnet or self.pose_guider is not None:
        return None, None

    conditional_only = guess_mode and do_classifier_free_guidance
    if conditional_only:
        # ControlNet sees only the conditional batch in this mode.
        control_model_input = self.scheduler.scale_model_input(latents, t)
        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
    else:
        control_model_input = latent_model_input
        controlnet_prompt_embeds = prompt_embeds

    step_keep = controlnet_keep[i]
    if isinstance(step_keep, list):
        cond_scale = [
            scale * keep
            for scale, keep in zip(controlnet_conditioning_scale, step_keep)
        ]
    else:
        cond_scale = controlnet_conditioning_scale * step_keep

    flat_input = rearrange(control_model_input, "b c t h w -> (b t) c h w")
    logger.debug(
        f"control_model_input_reshape={flat_input.shape}, controlnet_prompt_embeds={controlnet_prompt_embeds.shape}"
    )
    # Repeat the prompt embeddings along dim 0 to match the flattened frames.
    encoder_hidden_states_repeat = align_repeat_tensor_single_dim(
        controlnet_prompt_embeds,
        target_length=flat_input.shape[0],
        dim=0,
    )
    if self.print_idx == 0:
        logger.debug(
            f"control_model_input_reshape={flat_input.shape}, "
            f"encoder_hidden_states_repeat={encoder_hidden_states_repeat.shape}, "
        )

    down_block_res_samples, mid_block_res_sample = self.controlnet(
        flat_input,
        t,
        encoder_hidden_states_repeat,
        controlnet_cond=control_image,
        controlnet_cond_latents=controlnet_latents,
        conditioning_scale=cond_scale,
        guess_mode=guess_mode,
        return_dict=False,
    )

    if self.print_idx == 0:
        logger.debug(
            f"controlnet, len(down_block_res_samples, {len(down_block_res_samples)}",
        )
        for block_idx, block_sample in enumerate(down_block_res_samples):
            logger.debug(
                f"controlnet down_block_res_samples i={block_idx}, down_block_res_sample={block_sample.shape}"
            )
        logger.debug(
            f"controlnet mid_block_res_sample, {mid_block_res_sample.shape}"
        )

    if conditional_only:
        # Zero residuals for the unconditional half leave that half unchanged.
        down_block_res_samples = [
            torch.cat([torch.zeros_like(sample), sample])
            for sample in down_block_res_samples
        ]
        mid_block_res_sample = torch.cat(
            [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
        )
    return down_block_res_samples, mid_block_res_sample
def __call__( | |
self, | |
video_length: Optional[int], | |
prompt: Union[str, List[str]] = None, | |
# b c t h w | |
image: Union[ | |
torch.FloatTensor, | |
PIL.Image.Image, | |
np.ndarray, | |
List[torch.FloatTensor], | |
List[PIL.Image.Image], | |
List[np.ndarray], | |
] = None, | |
control_image: Union[ | |
torch.FloatTensor, | |
PIL.Image.Image, | |
np.ndarray, | |
List[torch.FloatTensor], | |
List[PIL.Image.Image], | |
List[np.ndarray], | |
] = None, | |
# b c t(1) ho wo | |
condition_images: Optional[torch.FloatTensor] = None, | |
condition_latents: Optional[torch.FloatTensor] = None, | |
latents: Optional[torch.FloatTensor] = None, | |
add_latents_noise: bool = False, | |
height: Optional[int] = None, | |
width: Optional[int] = None, | |
strength: float = 0.8, | |
num_inference_steps: int = 50, | |
guidance_scale: float = 7.5, | |
guidance_scale_end: float = None, | |
guidance_scale_method: str = "linear", | |
negative_prompt: Optional[Union[str, List[str]]] = None, | |
num_videos_per_prompt: Optional[int] = 1, | |
eta: float = 0.0, | |
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | |
# b c t(1) hi wi | |
controlnet_condition_images: Optional[torch.FloatTensor] = None, | |
# b c t(1) ho wo | |
controlnet_condition_latents: Optional[torch.FloatTensor] = None, | |
controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None, | |
prompt_embeds: Optional[torch.FloatTensor] = None, | |
negative_prompt_embeds: Optional[torch.FloatTensor] = None, | |
output_type: Optional[str] = "tensor", | |
return_dict: bool = True, | |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | |
callback_steps: int = 1, | |
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
controlnet_conditioning_scale: Union[float, List[float]] = 1.0, | |
guess_mode: bool = False, | |
control_guidance_start: Union[float, List[float]] = 0.0, | |
control_guidance_end: Union[float, List[float]] = 1.0, | |
need_middle_latents: bool = False, | |
w_ind_noise: float = 0.5, | |
initial_common_latent: Optional[torch.FloatTensor] = None, | |
latent_index: torch.LongTensor = None, | |
vision_condition_latent_index: torch.LongTensor = None, | |
# noise parameters | |
noise_type: str = "random", | |
need_img_based_video_noise: bool = False, | |
skip_temporal_layer: bool = False, | |
img_weight: float = 1e-3, | |
need_hist_match: bool = False, | |
motion_speed: float = 8.0, | |
refer_image: Optional[Tuple[torch.Tensor, np.array]] = None, | |
ip_adapter_image: Optional[Tuple[torch.Tensor, np.array]] = None, | |
refer_face_image: Optional[Tuple[torch.Tensor, np.array]] = None, | |
ip_adapter_scale: float = 1.0, | |
facein_scale: float = 1.0, | |
ip_adapter_face_scale: float = 1.0, | |
ip_adapter_face_image: Optional[Tuple[torch.Tensor, np.array]] = None, | |
prompt_only_use_image_prompt: bool = False, | |
# serial_denoise parameter start | |
record_mid_video_noises: bool = False, | |
last_mid_video_noises: List[torch.Tensor] = None, | |
record_mid_video_latents: bool = False, | |
last_mid_video_latents: List[torch.TensorType] = None, | |
video_overlap: int = 1, | |
# serial_denoise parameter end | |
# parallel_denoise parameter start | |
# refer to https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/src/pipelines/pipeline_pose2vid_long.py#L354 | |
context_schedule="uniform", | |
context_frames=12, | |
context_stride=1, | |
context_overlap=4, | |
context_batch_size=1, | |
interpolation_factor=1, | |
# parallel_denoise parameter end | |
): | |
r""" | |
旨在兼容text2video、text2image、img2img、video2video、是否有controlnet等的通用pipeline。目前仅不支持img2img、video2video。 | |
支持多片段同时denoise,交叉部分加权平均 | |
当 skip_temporal_layer 为 False 时, unet 起 video 生成作用;skip_temporal_layer为True时,unet起原image作用。 | |
当controlnet的所有入参为None,等价于走的是text2video pipeline; | |
当 condition_latents、controlnet_condition_images、controlnet_condition_latents为None时,表示不走首帧条件生成的时序condition pipeline | |
现在没有考虑对 `num_videos_per_prompt` 的兼容性,不是1可能报错; | |
if skip_temporal_layer is False, unet motion layer works, else unet only run text2image layers. | |
if parameters about controlnet are None, means text2video pipeline; | |
if ondition_latents、controlnet_condition_images、controlnet_condition_latents are None, means only run text2video without vision condition images. | |
By now, code works well with `num_videos_per_prpmpt=1`, !=1 may be wrong. | |
Args: | |
prompt (`str` or `List[str]`, *optional*): | |
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. | |
instead. | |
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: | |
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): | |
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If | |
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can | |
also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If | |
height and/or width are passed, `image` is resized according to them. If multiple ControlNets are | |
specified in init, images must be passed as a list such that each element of the list can be correctly | |
batched for input to a single controlnet. | |
condition_latents: | |
与latents相对应,是Latents的时序condition,一般为首帧,b c t(1) ho wo | |
be corresponding to latents, vision condtion latents, usually first frame, should be b c t(1) ho wo. | |
controlnet_latents: | |
与image二选一,image会被转化成controlnet_latents | |
Choose either image or controlnet_latents. If image is chosen, it will be converted to controlnet_latents. | |
controlnet_condition_images: | |
Optional[torch.FloatTensor]# b c t(1) ho wo,与image相对应,会和image在t通道concat一起,然后转化成 controlnet_latents | |
b c t(1) ho wo, corresponding to image, will be concatenated along the t channel with image and then converted to controlnet_latents. | |
controlnet_condition_latents: Optional[torch.FloatTensor]:# | |
b c t(1) ho wo,会和 controlnet_latents 在t 通道concat一起,转化成 controlnet_latents | |
b c t(1) ho wo will be concatenated along the t channel with controlnet_latents and converted to controlnet_latents. | |
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): | |
The height in pixels of the generated image. | |
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): | |
The width in pixels of the generated image. | |
num_inference_steps (`int`, *optional*, defaults to 50): | |
The number of denoising steps. More denoising steps usually lead to a higher quality image at the | |
expense of slower inference. | |
guidance_scale (`float`, *optional*, defaults to 7.5): | |
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). | |
`guidance_scale` is defined as `w` of equation 2. of [Imagen | |
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > | |
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, | |
usually at the expense of lower image quality. | |
negative_prompt (`str` or `List[str]`, *optional*): | |
The prompt or prompts not to guide the image generation. If not defined, one has to pass | |
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is | |
less than `1`). | |
strength (`float`, *optional*, defaults to 0.8): | |
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a | |
starting point and more noise is added the higher the `strength`. The number of denoising steps depends | |
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising | |
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 | |
essentially ignores `image`. | |
num_images_per_prompt (`int`, *optional*, defaults to 1): | |
The number of images to generate per prompt. | |
eta (`float`, *optional*, defaults to 0.0): | |
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to | |
[`schedulers.DDIMScheduler`], will be ignored for others. | |
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): | |
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) | |
to make generation deterministic. | |
latents (`torch.FloatTensor`, *optional*): | |
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image | |
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents | |
tensor will ge generated by sampling using the supplied random `generator`. | |
prompt_embeds (`torch.FloatTensor`, *optional*): | |
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not | |
provided, text embeddings will be generated from `prompt` input argument. | |
negative_prompt_embeds (`torch.FloatTensor`, *optional*): | |
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt | |
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input | |
argument. | |
output_type (`str`, *optional*, defaults to `"pil"`): | |
The output format of the generate image. Choose between | |
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | |
plain tuple. | |
callback (`Callable`, *optional*): | |
A function that will be called every `callback_steps` steps during inference. The function will be | |
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. | |
callback_steps (`int`, *optional*, defaults to 1): | |
The frequency at which the `callback` function will be called. If not specified, the callback will be | |
called at every step. | |
cross_attention_kwargs (`dict`, *optional*): | |
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under | |
`self.processor` in | |
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). | |
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): | |
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added | |
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the | |
corresponding scale as a list. | |
guess_mode (`bool`, *optional*, defaults to `False`): | |
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if | |
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. | |
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): | |
The percentage of total steps at which the controlnet starts applying. | |
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): | |
The percentage of total steps at which the controlnet stops applying. | |
skip_temporal_layer (`bool`: default to False) 为False时,unet起video生成作用,会运行时序生成的block;skip_temporal_layer为True时,unet起原image作用,跳过时序生成的block。 | |
need_img_based_video_noise: bool = False, 当只有首帧latents时,是否需要扩展为video noise; | |
num_videos_per_prompt: now only support 1. | |
Examples: | |
Returns: | |
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: | |
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. | |
When returning a tuple, the first element is a list with the generated images, and the second element is a | |
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" | |
(nsfw) content, according to the `safety_checker`. | |
""" | |
run_controlnet = control_image is not None or controlnet_latents is not None | |
if run_controlnet: | |
( | |
controlnet, | |
control_guidance_start, | |
control_guidance_end, | |
) = self.prepare_controlnet_and_guidance_parameter( | |
control_guidance_start=control_guidance_start, | |
control_guidance_end=control_guidance_end, | |
) | |
# 1. Check inputs. Raise error if not correct | |
self.check_inputs( | |
prompt, | |
control_image, | |
callback_steps, | |
negative_prompt, | |
prompt_embeds, | |
negative_prompt_embeds, | |
controlnet_conditioning_scale, | |
control_guidance_start, | |
control_guidance_end, | |
) | |
# 2. Define call parameters | |
if prompt is not None and isinstance(prompt, str): | |
batch_size = 1 | |
elif prompt is not None and isinstance(prompt, list): | |
batch_size = len(prompt) | |
else: | |
batch_size = prompt_embeds.shape[0] | |
device = self._execution_device | |
dtype = self.unet.dtype | |
# print("pipeline unet dtype", dtype) | |
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | |
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | |
# corresponds to doing no classifier free guidance. | |
do_classifier_free_guidance = guidance_scale > 1.0 | |
if run_controlnet: | |
if isinstance(controlnet, MultiControlNetModel) and isinstance( | |
controlnet_conditioning_scale, float | |
): | |
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len( | |
controlnet.nets | |
) | |
guess_mode = self.prepare_controlnet_guess_mode( | |
controlnet=controlnet, | |
guess_mode=guess_mode, | |
) | |
# 3. Encode input prompt | |
text_encoder_lora_scale = ( | |
cross_attention_kwargs.get("scale", None) | |
if cross_attention_kwargs is not None | |
else None | |
) | |
if self.text_encoder is not None: | |
prompt_embeds = encode_weighted_prompt( | |
self, | |
prompt, | |
device, | |
num_videos_per_prompt, | |
do_classifier_free_guidance, | |
negative_prompt, | |
prompt_embeds=prompt_embeds, | |
negative_prompt_embeds=negative_prompt_embeds, | |
# lora_scale=text_encoder_lora_scale, | |
) | |
logger.debug(f"use text_encoder prepare prompt_emb={prompt_embeds.shape}") | |
else: | |
prompt_embeds = None | |
if image is not None: | |
image = self.prepare_image( | |
image, | |
width=width, | |
height=height, | |
batch_size=batch_size * num_videos_per_prompt, | |
num_images_per_prompt=num_videos_per_prompt, | |
device=device, | |
dtype=dtype, | |
) | |
if self.print_idx == 0: | |
logger.debug(f"image={image.shape}") | |
if condition_images is not None: | |
condition_images = self.prepare_image( | |
condition_images, | |
width=width, | |
height=height, | |
batch_size=batch_size * num_videos_per_prompt, | |
num_images_per_prompt=num_videos_per_prompt, | |
device=device, | |
dtype=dtype, | |
) | |
if self.print_idx == 0: | |
logger.debug(f"condition_images={condition_images.shape}") | |
# 4. Prepare image | |
if run_controlnet: | |
( | |
control_image, | |
controlnet_latents, | |
) = self.prepare_controlnet_image_and_latents( | |
controlnet=controlnet, | |
width=width, | |
height=height, | |
batch_size=batch_size, | |
num_videos_per_prompt=num_videos_per_prompt, | |
device=device, | |
dtype=dtype, | |
controlnet_condition_latents=controlnet_condition_latents, | |
control_image=control_image, | |
controlnet_condition_images=controlnet_condition_images, | |
guess_mode=guess_mode, | |
do_classifier_free_guidance=do_classifier_free_guidance, | |
controlnet_latents=controlnet_latents, | |
) | |
# 5. Prepare timesteps | |
self.scheduler.set_timesteps(num_inference_steps, device=device) | |
if strength and (image is not None and latents is not None): | |
if self.print_idx == 0: | |
logger.debug( | |
f"prepare timesteps, with get_timesteps strength={strength}, num_inference_steps={num_inference_steps}" | |
) | |
timesteps, num_inference_steps = self.get_timesteps( | |
num_inference_steps, strength, device | |
) | |
else: | |
if self.print_idx == 0: | |
logger.debug(f"prepare timesteps, without get_timesteps") | |
timesteps = self.scheduler.timesteps | |
latent_timestep = timesteps[:1].repeat( | |
batch_size * num_videos_per_prompt | |
) # 6. Prepare latent variables | |
( | |
condition_latents, | |
latent_index, | |
vision_condition_latent_index, | |
) = self.prepare_condition_latents_and_index( | |
condition_images=condition_images, | |
condition_latents=condition_latents, | |
video_length=video_length, | |
batch_size=batch_size, | |
dtype=dtype, | |
device=device, | |
latent_index=latent_index, | |
vision_condition_latent_index=vision_condition_latent_index, | |
) | |
if vision_condition_latent_index is None: | |
n_vision_cond = 0 | |
else: | |
n_vision_cond = vision_condition_latent_index.shape[0] | |
num_channels_latents = self.unet.config.in_channels | |
if self.print_idx == 0: | |
logger.debug(f"pipeline controlnet, start prepare latents") | |
latents = self.prepare_latents( | |
batch_size=batch_size * num_videos_per_prompt, | |
num_channels_latents=num_channels_latents, | |
video_length=video_length, | |
height=height, | |
width=width, | |
dtype=dtype, | |
device=device, | |
generator=generator, | |
latents=latents, | |
image=image, | |
timestep=latent_timestep, | |
w_ind_noise=w_ind_noise, | |
initial_common_latent=initial_common_latent, | |
noise_type=noise_type, | |
add_latents_noise=add_latents_noise, | |
need_img_based_video_noise=need_img_based_video_noise, | |
condition_latents=condition_latents, | |
img_weight=img_weight, | |
) | |
if self.print_idx == 0: | |
logger.debug(f"pipeline controlnet, finish prepare latents={latents.shape}") | |
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | |
if noise_type == "video_fusion" and "noise_type" in set( | |
inspect.signature(self.scheduler.step).parameters.keys() | |
): | |
extra_step_kwargs["w_ind_noise"] = w_ind_noise | |
extra_step_kwargs["noise_type"] = noise_type | |
# extra_step_kwargs["noise_offset"] = noise_offset | |
# 7.1 Create tensor stating which controlnets to keep | |
if run_controlnet: | |
controlnet_keep = [] | |
for i in range(len(timesteps)): | |
keeps = [ | |
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) | |
for s, e in zip(control_guidance_start, control_guidance_end) | |
] | |
controlnet_keep.append( | |
keeps[0] if isinstance(controlnet, ControlNetModel) else keeps | |
) | |
else: | |
controlnet_keep = None | |
# 8. Denoising loop | |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order | |
if skip_temporal_layer: | |
self.unet.set_skip_temporal_layers(True) | |
n_timesteps = len(timesteps) | |
guidance_scale_lst = generate_parameters_with_timesteps( | |
start=guidance_scale, | |
stop=guidance_scale_end, | |
num=n_timesteps, | |
method=guidance_scale_method, | |
) | |
if self.print_idx == 0: | |
logger.debug( | |
f"guidance_scale_lst, {guidance_scale_method}, {guidance_scale}, {guidance_scale_end}, {guidance_scale_lst}" | |
) | |
ip_adapter_image_emb = self.get_ip_adapter_image_emb( | |
ip_adapter_image=ip_adapter_image, | |
batch_size=batch_size, | |
device=device, | |
dtype=dtype, | |
do_classifier_free_guidance=do_classifier_free_guidance, | |
height=height, | |
width=width, | |
) | |
# 当前仅当没有ip_adapter时,按照参数 prompt_only_use_image_prompt 要求是否完全替换 image_prompt_emb | |
# only if ip_adapter is None and prompt_only_use_image_prompt is True, use image_prompt_emb replace text_prompt | |
if ( | |
ip_adapter_image_emb is not None | |
and prompt_only_use_image_prompt | |
and not self.unet.ip_adapter_cross_attn | |
): | |
prompt_embeds = ip_adapter_image_emb | |
logger.debug(f"use ip_adapter_image_emb replace prompt_embeds") | |
refer_face_image_emb = self.get_facein_image_emb( | |
refer_face_image=refer_face_image, | |
batch_size=batch_size, | |
device=device, | |
dtype=dtype, | |
do_classifier_free_guidance=do_classifier_free_guidance, | |
) | |
ip_adapter_face_emb = self.get_ip_adapter_face_emb( | |
refer_face_image=ip_adapter_face_image, | |
batch_size=batch_size, | |
device=device, | |
dtype=dtype, | |
do_classifier_free_guidance=do_classifier_free_guidance, | |
) | |
refer_image_vae_emb = self.get_referencenet_image_vae_emb( | |
refer_image=refer_image, | |
device=device, | |
dtype=dtype, | |
do_classifier_free_guidance=do_classifier_free_guidance, | |
num_videos_per_prompt=num_videos_per_prompt, | |
batch_size=batch_size, | |
width=width, | |
height=height, | |
) | |
if self.pose_guider is not None and control_image is not None: | |
if self.print_idx == 0: | |
logger.debug(f"pose_guider, controlnet_image={control_image.shape}") | |
control_image = rearrange( | |
control_image, " (b t) c h w->b c t h w", t=video_length | |
) | |
pose_guider_emb = self.pose_guider(control_image) | |
pose_guider_emb = rearrange(pose_guider_emb, "b c t h w-> (b t) c h w") | |
else: | |
pose_guider_emb = None | |
logger.debug(f"prompt_embeds={prompt_embeds.shape}") | |
if control_image is not None: | |
if isinstance(control_image, list): | |
logger.debug(f"control_imageis list, num={len(control_image)}") | |
control_image = [ | |
rearrange( | |
control_image_tmp, | |
" (b t) c h w->b c t h w", | |
b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, | |
) | |
for control_image_tmp in control_image | |
] | |
else: | |
logger.debug(f"control_image={control_image.shape}, before") | |
control_image = rearrange( | |
control_image, | |
" (b t) c h w->b c t h w", | |
b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, | |
) | |
logger.debug(f"control_image={control_image.shape}, after") | |
if controlnet_latents is not None: | |
if isinstance(controlnet_latents, list): | |
logger.debug( | |
f"controlnet_latents is list, num={len(controlnet_latents)}" | |
) | |
controlnet_latents = [ | |
rearrange( | |
controlnet_latents_tmp, | |
" (b t) c h w->b c t h w", | |
b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, | |
) | |
for controlnet_latents_tmp in controlnet_latents | |
] | |
else: | |
logger.debug(f"controlnet_latents={controlnet_latents.shape}, before") | |
controlnet_latents = rearrange( | |
controlnet_latents, | |
" (b t) c h w->b c t h w", | |
b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, | |
) | |
logger.debug(f"controlnet_latents={controlnet_latents.shape}, after") | |
videos_mid = [] | |
mid_video_noises = [] if record_mid_video_noises else None | |
mid_video_latents = [] if record_mid_video_latents else None | |
global_context = prepare_global_context( | |
context_schedule=context_schedule, | |
num_inference_steps=num_inference_steps, | |
time_size=latents.shape[2], | |
context_frames=context_frames, | |
context_stride=context_stride, | |
context_overlap=context_overlap, | |
context_batch_size=context_batch_size, | |
) | |
logger.debug( | |
f"context_schedule={context_schedule}, time_size={latents.shape[2]}, context_frames={context_frames}, context_stride={context_stride}, context_overlap={context_overlap}, context_batch_size={context_batch_size}" | |
) | |
logger.debug(f"global_context={global_context}") | |
# iterative denoise | |
with self.progress_bar(total=num_inference_steps) as progress_bar: | |
for i, t in enumerate(timesteps): | |
# 使用 last_mid_video_latents 来影响初始化latent,该部分效果较差,暂留代码 | |
# use last_mide_video_latents to affect initial latent. works bad, Temporarily reserved | |
if i == 0: | |
if record_mid_video_latents: | |
mid_video_latents.append(latents[:, :, -video_overlap:]) | |
if record_mid_video_noises: | |
mid_video_noises.append(None) | |
if ( | |
last_mid_video_latents is not None | |
and len(last_mid_video_latents) > 0 | |
): | |
if self.print_idx == 1: | |
logger.debug( | |
f"{i}, last_mid_video_latents={last_mid_video_latents[i].shape}" | |
) | |
latents = fuse_part_tensor( | |
last_mid_video_latents[0], | |
latents, | |
video_overlap, | |
weight=0.1, | |
skip_step=0, | |
) | |
noise_pred = torch.zeros( | |
( | |
latents.shape[0] * (2 if do_classifier_free_guidance else 1), | |
*latents.shape[1:], | |
), | |
device=latents.device, | |
dtype=latents.dtype, | |
) | |
counter = torch.zeros( | |
(1, 1, latents.shape[2], 1, 1), | |
device=latents.device, | |
dtype=latents.dtype, | |
) | |
if i == 0: | |
( | |
down_block_refer_embs, | |
mid_block_refer_emb, | |
refer_self_attn_emb, | |
) = self.get_referencenet_emb( | |
refer_image_vae_emb=refer_image_vae_emb, | |
refer_image=refer_image, | |
device=device, | |
dtype=dtype, | |
do_classifier_free_guidance=do_classifier_free_guidance, | |
num_videos_per_prompt=num_videos_per_prompt, | |
prompt_embeds=prompt_embeds, | |
ip_adapter_image_emb=ip_adapter_image_emb, | |
batch_size=batch_size, | |
ref_timestep_int=t, | |
) | |
for context in global_context: | |
# expand the latents if we are doing classifier free guidance | |
latents_c = torch.cat([latents[:, :, c] for c in context]) | |
latent_index_c = ( | |
torch.cat([latent_index[c] for c in context]) | |
if latent_index is not None | |
else None | |
) | |
latent_model_input = latents_c.to(device).repeat( | |
2 if do_classifier_free_guidance else 1, 1, 1, 1, 1 | |
) | |
latent_model_input = self.scheduler.scale_model_input( | |
latent_model_input, t | |
) | |
sub_latent_index_c = ( | |
torch.LongTensor( | |
torch.arange(latent_index_c.shape[-1]) + n_vision_cond | |
).to(device=latents_c.device) | |
if latent_index is not None | |
else None | |
) | |
if condition_latents is not None: | |
latent_model_condition = ( | |
torch.cat([condition_latents] * 2) | |
if do_classifier_free_guidance | |
else latents | |
) | |
if self.print_idx == 0: | |
logger.debug( | |
f"vision_condition_latent_index, {vision_condition_latent_index.shape}, vision_condition_latent_index" | |
) | |
logger.debug( | |
f"latent_model_condition, {latent_model_condition.shape}" | |
) | |
logger.debug(f"latent_index, {latent_index_c.shape}") | |
logger.debug( | |
f"latent_model_input, {latent_model_input.shape}" | |
) | |
logger.debug(f"sub_latent_index_c, {sub_latent_index_c}") | |
latent_model_input = batch_concat_two_tensor_with_index( | |
data1=latent_model_condition, | |
data1_index=vision_condition_latent_index, | |
data2=latent_model_input, | |
data2_index=sub_latent_index_c, | |
dim=2, | |
) | |
if control_image is not None: | |
if vision_condition_latent_index is not None: | |
# 获取 vision_condition 对应的 control_imgae/control_latent 部分 | |
# generate control_image/control_latent corresponding to vision_condition | |
controlnet_condtion_latent_index = ( | |
vision_condition_latent_index.clone().cpu().tolist() | |
) | |
if self.print_idx == 0: | |
logger.debug( | |
f"context={context}, controlnet_condtion_latent_index={controlnet_condtion_latent_index}" | |
) | |
controlnet_context = [ | |
controlnet_condtion_latent_index | |
+ [c_i + n_vision_cond for c_i in c] | |
for c in context | |
] | |
else: | |
controlnet_context = context | |
if self.print_idx == 0: | |
logger.debug( | |
f"controlnet_context={controlnet_context}, latent_model_input={latent_model_input.shape}" | |
) | |
if isinstance(control_image, list): | |
control_image_c = [ | |
torch.cat( | |
[ | |
control_image_tmp[:, :, c] | |
for c in controlnet_context | |
] | |
) | |
for control_image_tmp in control_image | |
] | |
control_image_c = [ | |
rearrange(control_image_tmp, " b c t h w-> (b t) c h w") | |
for control_image_tmp in control_image_c | |
] | |
else: | |
control_image_c = torch.cat( | |
[control_image[:, :, c] for c in controlnet_context] | |
) | |
control_image_c = rearrange( | |
control_image_c, " b c t h w-> (b t) c h w" | |
) | |
else: | |
control_image_c = None | |
if controlnet_latents is not None: | |
if vision_condition_latent_index is not None: | |
# 获取 vision_condition 对应的 control_imgae/control_latent 部分 | |
# generate control_image/control_latent corresponding to vision_condition | |
controlnet_condtion_latent_index = ( | |
vision_condition_latent_index.clone().cpu().tolist() | |
) | |
if self.print_idx == 0: | |
logger.debug( | |
f"context={context}, controlnet_condtion_latent_index={controlnet_condtion_latent_index}" | |
) | |
controlnet_context = [ | |
controlnet_condtion_latent_index | |
+ [c_i + n_vision_cond for c_i in c] | |
for c in context | |
] | |
else: | |
controlnet_context = context | |
if self.print_idx == 0: | |
logger.debug( | |
f"controlnet_context={controlnet_context}, controlnet_latents={controlnet_latents.shape}, latent_model_input={latent_model_input.shape}," | |
) | |
controlnet_latents_c = torch.cat( | |
[controlnet_latents[:, :, c] for c in controlnet_context] | |
) | |
controlnet_latents_c = rearrange( | |
controlnet_latents_c, " b c t h w-> (b t) c h w" | |
) | |
else: | |
controlnet_latents_c = None | |
( | |
down_block_res_samples, | |
mid_block_res_sample, | |
) = self.get_controlnet_emb( | |
run_controlnet=run_controlnet, | |
guess_mode=guess_mode, | |
do_classifier_free_guidance=do_classifier_free_guidance, | |
latents=latents_c, | |
prompt_embeds=prompt_embeds, | |
latent_model_input=latent_model_input, | |
control_image=control_image_c, | |
controlnet_latents=controlnet_latents_c, | |
controlnet_keep=controlnet_keep, | |
t=t, | |
i=i, | |
controlnet_conditioning_scale=controlnet_conditioning_scale, | |
) | |
if self.print_idx == 0: | |
logger.debug( | |
f"{i}, latent_model_input={latent_model_input.shape}, sub_latent_index_c={sub_latent_index_c}" | |
f"{vision_condition_latent_index}" | |
) | |
# time.sleep(10) | |
noise_pred_c = self.unet( | |
latent_model_input, | |
t, | |
encoder_hidden_states=prompt_embeds, | |
cross_attention_kwargs=cross_attention_kwargs, | |
down_block_additional_residuals=down_block_res_samples, | |
mid_block_additional_residual=mid_block_res_sample, | |
return_dict=False, | |
sample_index=sub_latent_index_c, | |
vision_conditon_frames_sample_index=vision_condition_latent_index, | |
sample_frame_rate=motion_speed, | |
down_block_refer_embs=down_block_refer_embs, | |
mid_block_refer_emb=mid_block_refer_emb, | |
refer_self_attn_emb=refer_self_attn_emb, | |
vision_clip_emb=ip_adapter_image_emb, | |
face_emb=refer_face_image_emb, | |
ip_adapter_scale=ip_adapter_scale, | |
facein_scale=facein_scale, | |
ip_adapter_face_emb=ip_adapter_face_emb, | |
ip_adapter_face_scale=ip_adapter_face_scale, | |
do_classifier_free_guidance=do_classifier_free_guidance, | |
pose_guider_emb=pose_guider_emb, | |
)[0] | |
if condition_latents is not None: | |
noise_pred_c = batch_index_select( | |
noise_pred_c, dim=2, index=sub_latent_index_c | |
).contiguous() | |
if self.print_idx == 0: | |
logger.debug( | |
f"{i}, latent_model_input={latent_model_input.shape}, noise_pred_c={noise_pred_c.shape}, {len(context)}, {len(context[0])}" | |
) | |
for j, c in enumerate(context): | |
noise_pred[:, :, c] = noise_pred[:, :, c] + noise_pred_c | |
counter[:, :, c] = counter[:, :, c] + 1 | |
noise_pred = noise_pred / counter | |
if ( | |
last_mid_video_noises is not None | |
and len(last_mid_video_noises) > 0 | |
and i <= num_inference_steps // 2 # 是个超参数 super paramter | |
): | |
if self.print_idx == 1: | |
logger.debug( | |
f"{i}, last_mid_video_noises={last_mid_video_noises[i].shape}" | |
) | |
noise_pred = fuse_part_tensor( | |
last_mid_video_noises[i + 1], | |
noise_pred, | |
video_overlap, | |
weight=0.01, | |
skip_step=1, | |
) | |
if record_mid_video_noises: | |
mid_video_noises.append(noise_pred[:, :, -video_overlap:]) | |
# perform guidance | |
if do_classifier_free_guidance: | |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
noise_pred = noise_pred_uncond + guidance_scale_lst[i] * ( | |
noise_pred_text - noise_pred_uncond | |
) | |
if self.print_idx == 0: | |
logger.debug( | |
f"before step, noise_pred={noise_pred.shape}, {noise_pred.device}, latents={latents.shape}, {latents.device}, t={t}" | |
) | |
# compute the previous noisy sample x_t -> x_t-1 | |
latents = self.scheduler.step( | |
noise_pred, | |
t, | |
latents, | |
**extra_step_kwargs, | |
).prev_sample | |
if ( | |
last_mid_video_latents is not None | |
and len(last_mid_video_latents) > 0 | |
and i <= 1 # 超参数, super parameter | |
): | |
if self.print_idx == 1: | |
logger.debug( | |
f"{i}, last_mid_video_latents={last_mid_video_latents[i].shape}" | |
) | |
latents = fuse_part_tensor( | |
last_mid_video_latents[i + 1], | |
latents, | |
video_overlap, | |
weight=0.1, | |
skip_step=0, | |
) | |
if record_mid_video_latents: | |
mid_video_latents.append(latents[:, :, -video_overlap:]) | |
if need_middle_latents is True: | |
videos_mid.append(self.decode_latents(latents)) | |
# call the callback, if provided | |
if i == len(timesteps) - 1 or ( | |
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 | |
): | |
progress_bar.update() | |
if callback is not None and i % callback_steps == 0: | |
callback(i, t, latents) | |
self.print_idx += 1 | |
if condition_latents is not None: | |
latents = batch_concat_two_tensor_with_index( | |
data1=condition_latents, | |
data1_index=vision_condition_latent_index, | |
data2=latents, | |
data2_index=latent_index, | |
dim=2, | |
) | |
video = self.decode_latents(latents) | |
if skip_temporal_layer: | |
self.unet.set_skip_temporal_layers(False) | |
if need_hist_match: | |
video[:, :, latent_index, :, :] = self.hist_match_with_vis_cond( | |
batch_index_select(video, index=latent_index, dim=2), | |
batch_index_select(video, index=vision_condition_latent_index, dim=2), | |
) | |
# Convert to tensor | |
if output_type == "tensor": | |
videos_mid = [torch.from_numpy(x) for x in videos_mid] | |
video = torch.from_numpy(video) | |
else: | |
latents = latents.cpu().numpy() | |
if not return_dict: | |
return ( | |
video, | |
latents, | |
videos_mid, | |
mid_video_latents, | |
mid_video_noises, | |
) | |
return VideoPipelineOutput( | |
videos=video, | |
latents=latents, | |
videos_mid=videos_mid, | |
mid_video_latents=mid_video_latents, | |
mid_video_noises=mid_video_noises, | |
) | |