Spaces:
Running
on
Zero
Running
on
Zero
import os | |
from typing import List | |
import torch | |
from diffusers import StableDiffusionPipeline | |
from diffusers.pipelines.controlnet import MultiControlNetModel | |
from PIL import Image | |
from safetensors import safe_open | |
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection | |
from .attention_processor import LoRAFaceAttnProcessor | |
from .utils import is_torch2_available, get_generator | |
if is_torch2_available(): | |
from .attention_processor import ( | |
AttnProcessor2_0 as AttnProcessor, | |
) | |
else: | |
from .attention_processor import AttnProcessor | |
from .resampler import Resampler | |
class ImageProjModel(torch.nn.Module):
    """Turn a pooled CLIP image embedding into a short sequence of context tokens.

    One linear layer expands each ``clip_embeddings_dim`` vector into
    ``clip_extra_context_tokens * cross_attention_dim`` features. Those are
    reshaped into ``clip_extra_context_tokens`` tokens of width
    ``cross_attention_dim`` and then layer-normalised per token.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        # Not read anywhere in this class; kept in case external code relies on it.
        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        expanded_width = clip_extra_context_tokens * cross_attention_dim
        # Registration order (proj, then norm) defines the checkpoint keys; keep it.
        self.proj = torch.nn.Linear(clip_embeddings_dim, expanded_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return tokens shaped ``(N, clip_extra_context_tokens, cross_attention_dim)``.

        All leading dimensions of the projected input are folded into ``N`` by
        the ``reshape(-1, ...)``.
        """
        token_shape = (-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        tokens = self.proj(image_embeds).reshape(*token_shape)
        return self.norm(tokens)
class MLPProjModel(torch.nn.Module):
    """Two-layer MLP that maps CLIP image embeddings to the cross-attention width.

    Pipeline (applied along the last dimension):
    Linear(clip_embeddings_dim -> clip_embeddings_dim) -> GELU
    -> Linear(clip_embeddings_dim -> cross_attention_dim) -> LayerNorm.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
        super().__init__()
        # The position of each layer inside the Sequential fixes the checkpoint
        # keys (proj.0.*, proj.2.*, proj.3.*), so the order must not change.
        layers = [
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)

    def forward(self, image_embeds):
        """Project ``image_embeds``; leading dims are kept, the last becomes ``cross_attention_dim``."""
        return self.proj(image_embeds)
class FaceAdapterLora:
    """Attach a LoRA-based face (ID) adapter to a diffusers pipeline.

    Reference-face CLIP image embeddings are projected into ``num_tokens``
    extra context tokens, which are appended to the text prompt embeddings.
    The unet's cross-attention processors are replaced with
    ``LoRAFaceAttnProcessor`` instances. Their weights, together with the
    projection model's, are loaded from ``id_ckpt``.

    Args:
        sd_pipe: pipeline object exposing ``.unet``, ``.encode_prompt(...)`` and a
            call returning ``.images``; it is moved to ``device``.
        image_encoder_path: path or hub id passed to
            ``CLIPVisionModelWithProjection.from_pretrained``.
        id_ckpt: path to a ``torch.load``-able checkpoint containing
            ``image_proj_model.*`` and ``adapter_modules.*`` weights.
        device: device for the pipeline, image encoder and adapter modules.
        num_tokens: number of image-prompt tokens appended to the text tokens.
        torch_type: dtype for the image encoder and adapter modules.
    """

    def __init__(self, sd_pipe, image_encoder_path, id_ckpt, device, num_tokens=4, torch_type=torch.float32):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.id_ckpt = id_ckpt
        self.num_tokens = num_tokens
        self.torch_type = torch_type

        self.pipe = sd_pipe.to(self.device)
        # Processors must be installed before load_face_adapter(), which loads
        # weights into the unet's current attention processors.
        self.set_face_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=self.torch_type
        )
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model (init_proj reads image_encoder.config, so it must come after)
        self.image_proj_model = self.init_proj()

        self.load_face_adapter()

    def init_proj(self):
        """Build the projection from pooled CLIP embeddings to ``num_tokens`` context tokens."""
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=self.torch_type)
        return image_proj_model

    def set_face_adapter(self):
        """Replace every attention processor of ``self.pipe.unet``.

        Processors whose name ends in ``attn1.processor`` get a plain
        ``AttnProcessor``. Every other processor gets a ``LoRAFaceAttnProcessor``
        (rank 128, scale 1.0) sized to the hidden width of its up/down/mid block.
        """
        unet = self.pipe.unet
        block_out_channels = unet.config.block_out_channels
        attn_procs = {}
        for name in unet.attn_processors:
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            # NOTE(review): a name matching none of these prefixes leaves hidden_size
            # unset/stale; this only matters for the LoRAFaceAttnProcessor branch.
            if name.startswith("mid_block"):
                hidden_size = block_out_channels[-1]
            elif name.startswith("up_blocks"):
                # Parse the whole index field (e.g. "up_blocks.12.") rather than one character.
                block_id = int(name.split(".")[1])
                hidden_size = list(reversed(block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name.split(".")[1])
                hidden_size = block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor().to(self.device, dtype=self.torch_type)
            else:
                attn_procs[name] = LoRAFaceAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    rank=128,
                    num_tokens=self.num_tokens,
                ).to(self.device, dtype=self.torch_type)
        unet.set_attn_processor(attn_procs)

    def load_face_adapter(self):
        """Load projection-model and adapter weights from ``self.id_ckpt``.

        Keys must start with ``image_proj_model`` or ``adapter_modules.``,
        optionally prefixed by ``module.`` (as in checkpoints saved from a
        wrapped model). The whole dict may be nested under a ``"state_dict"`` entry.

        Raises:
            ValueError: if the checkpoint contains a key with any other prefix.
                This makes a mismatched checkpoint fail loudly instead of
                leaving the adapter with randomly initialised weights.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(self.id_ckpt, map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]

        image_proj_dict = {}
        face_adapter_proj = {}
        for k, v in state_dict.items():
            if k.startswith("module.image_proj_model"):
                image_proj_dict[k.replace("module.image_proj_model.", "")] = v
            elif k.startswith("module.adapter_modules."):
                face_adapter_proj[k.replace("module.adapter_modules.", "")] = v
            elif k.startswith("image_proj_model"):
                image_proj_dict[k.replace("image_proj_model.", "")] = v
            elif k.startswith("adapter_modules."):
                face_adapter_proj[k.replace("adapter_modules.", "")] = v
            else:
                raise ValueError(
                    f"Unexpected key {k!r} in face-adapter checkpoint {self.id_ckpt!r}; "
                    "expected keys starting with 'image_proj_model' or 'adapter_modules.' "
                    "(optionally prefixed by 'module.')."
                )

        self.image_proj_model.load_state_dict(image_proj_dict)
        # Adapter keys are positional ("<index>.<param>") over the unet's processors.
        adapter_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        # strict=False: missing/unexpected adapter keys are ignored rather than raised.
        adapter_layers.load_state_dict(face_adapter_proj, strict=False)

    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        """Return ``(image_prompt_embeds, uncond_image_prompt_embeds)``.

        Args:
            pil_image: a PIL image or list of images. When given, it is encoded
                with the CLIP vision model (pooled ``image_embeds``).
            clip_image_embeds: precomputed pooled CLIP embeddings, used only
                when ``pil_image`` is None.

        The unconditional embeddings are the projection of an all-zero CLIP
        embedding of the same shape.

        Raises:
            ValueError: if neither ``pil_image`` nor ``clip_image_embeds`` is given.
        """
        if pil_image is None and clip_image_embeds is None:
            raise ValueError("get_image_embeds() needs either `pil_image` or `clip_image_embeds`.")
        # Inference only: don't build an autograd graph through the encoder/projection.
        with torch.no_grad():
            if pil_image is not None:
                if isinstance(pil_image, Image.Image):
                    pil_image = [pil_image]
                clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
                clip_image_embeds = self.image_encoder(
                    clip_image.to(self.device, dtype=self.torch_type)
                ).image_embeds
            else:
                clip_image_embeds = clip_image_embeds.to(self.device, dtype=self.torch_type)
            image_prompt_embeds = self.image_proj_model(clip_image_embeds)
            uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_attn_scale(self, attn_scale):
        """Set ``scale`` on every ``LoRAFaceAttnProcessor`` in the unet.

        This scales the face-adapter face_hidden_states (attention output);
        the processors are created with scale 1.0
        (see faceadapter/attention_processor.py:L283).
        """
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, LoRAFaceAttnProcessor):
                attn_processor.scale = attn_scale

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        negative_prompt=None,
        attn_scale=1,
        num_samples=4,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        **kwargs,
    ):
        """Generate images conditioned on text and a reference face.

        The face tokens from ``get_image_embeds`` are appended along the
        sequence dimension to the text embeddings from ``pipe.encode_prompt``.
        The unconditional face tokens are appended to the negative embeddings.
        Extra ``kwargs`` are forwarded to the pipeline call.

        Returns:
            The pipeline output's ``.images``.
        """
        self.set_attn_scale(attn_scale)

        if pil_image is not None:
            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, list):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )
        # Duplicate each reference's tokens num_samples times: (B, T, D) -> (B * num_samples, T, D).
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images
class FaceAdapterPlusForVideoLora(FaceAdapterLora):
    """Face-adapter variant for video pipelines, using a ``Resampler`` projection.

    The image prompt is built from the CLIP vision encoder's penultimate hidden
    states (``hidden_states[-2]``) rather than the pooled embedding.
    Generation calls ``self.pipe(...)`` / ``self.pipe.video_edit(...)`` and
    returns their ``.videos``.
    """

    def init_proj(self):
        """Build a Resampler mapping CLIP hidden states to ``num_tokens`` context tokens."""
        image_proj_model = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=self.torch_type)
        return image_proj_model

    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        """Return ``(image_prompt_embeds, uncond_image_prompt_embeds)`` for ``pil_image``.

        ``clip_image_embeds`` is accepted for signature compatibility with the
        base class but is ignored: this variant needs the encoder's hidden
        states, so it always encodes ``pil_image``. The unconditional
        embeddings come from encoding an all-zero pixel tensor of the same shape.

        Raises:
            ValueError: if ``pil_image`` is None.
        """
        if pil_image is None:
            raise ValueError("FaceAdapterPlusForVideoLora.get_image_embeds() requires `pil_image`.")
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        # Inference only: don't build an autograd graph through the encoder/projection.
        with torch.no_grad():
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            clip_image = clip_image.to(self.device, dtype=self.torch_type)
            clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
            image_prompt_embeds = self.image_proj_model(clip_image_embeds)
            uncond_clip_image_embeds = self.image_encoder(
                torch.zeros_like(clip_image), output_hidden_states=True
            ).hidden_states[-2]
            uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds

    def generate(
        self,
        pil_image=None,
        init_image=None,
        init_image_strength=1.,
        clip_image_embeds=None,
        prompt=None,
        negative_prompt=None,
        adaface_embeds=None,
        adaface_scale=1.0,
        attn_scale=1.0,
        num_samples=1,
        seed=None,
        guidance_scale=4,
        num_inference_steps=30,
        adaface_anneal_steps=0,
        width=512,
        height=512,
        video_length=16,
        image_embed_scale=1,
        controlnet_images: torch.FloatTensor = None,
        controlnet_image_index: list = None,
        **kwargs,
    ):
        """Generate a video conditioned on text and one or more reference faces.

        The face tokens of all images in ``pil_image`` are averaged. When
        ``adaface_embeds`` is given, the pipeline receives a
        ``(prompt_embeds_begin, prompt_embeds_end, adaface_anneal_steps)``
        tuple instead of a single tensor:
        - begin = the encoded text prompt with zeroed face tokens;
        - end = the scaled ``adaface_embeds`` with face tokens interpolated
          toward the unconditional ones by ``image_embed_scale``.

        Args:
            pil_image: a PIL image or a non-empty list of them.
            controlnet_image_index: forwarded to the pipeline; defaults to ``[0]``.
            Other arguments are forwarded to ``self.pipe`` as named.

        Returns:
            The pipeline output's ``.videos``.

        Raises:
            ValueError: if no reference image is supplied.
        """
        self.set_attn_scale(attn_scale)
        # Fresh list per call, instead of a shared mutable default.
        if controlnet_image_index is None:
            controlnet_image_index = [0]
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        if pil_image is None or len(pil_image) == 0:
            raise ValueError("generate() requires at least one reference face image in `pil_image`.")

        num_prompts = 1
        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
        if not isinstance(prompt, list):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts

        # Average the per-sample-duplicated face tokens over all reference images.
        total_image_prompt_embeds = 0
        for prompt_img in pil_image:
            image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
                pil_image=prompt_img, clip_image_embeds=clip_image_embeds
            )
            bs_embed, seq_len, _ = image_prompt_embeds.shape
            image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
            image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
            total_image_prompt_embeds += image_prompt_embeds
        image_prompt_embeds = total_image_prompt_embeds / len(pil_image)
        # The uncond tokens encode an all-zero image, so the last iteration's value is used.
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            # if do_classifier_free_guidance,
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method.
            # https://github.com/huggingface/diffusers/blob/70f8d4b488f03730ae3bc11d4d707bafe153d10d/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L469
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_videos_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            if adaface_embeds is not None:
                prompt_embeds0_ = prompt_embeds_
                # adaface_embeds may be float32 while the adapter runs in self.torch_type
                # (e.g. float16), so cast explicitly.
                prompt_embeds_ = adaface_embeds.repeat(num_samples, 1, 1).to(dtype=self.torch_type) * adaface_scale
                # Scale down ID-Animator's face embeddings so that they don't dominate the
                # generation; interpolate toward the uncond embeddings to keep them balanced.
                image_prompt_embeds = (
                    image_prompt_embeds * image_embed_scale
                    + uncond_image_prompt_embeds * (1 - image_embed_scale)
                )
                # We still need uncond_image_prompt_embeds, otherwise the output is blank.
                prompt_embeds_end = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
                prompt_embeds_begin = torch.cat([prompt_embeds0_, torch.zeros_like(image_prompt_embeds)], dim=1)
                prompt_embeds = (prompt_embeds_begin, prompt_embeds_end, adaface_anneal_steps)
            else:
                prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = get_generator(seed, self.device)
        video = self.pipe(
            init_image=init_image,
            init_image_strength=init_image_strength,
            prompt="",
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            width=width,
            height=height,
            video_length=video_length,
            controlnet_images=controlnet_images,
            controlnet_image_index=controlnet_image_index,
            **kwargs,
        ).videos
        return video

    def generate_video_edit(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        negative_prompt=None,
        attn_scale=1.0,
        num_samples=1,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        width=512,
        height=512,
        video_length=16,
        video_latents=None,
        **kwargs,
    ):
        """Edit a video (starting from ``video_latents``) conditioned on text and a reference face.

        Face tokens are appended to the text embeddings as in ``generate``,
        without averaging or AdaFace mixing. The result comes from
        ``self.pipe.video_edit(...)``.

        Returns:
            The ``video_edit`` output's ``.videos``.
        """
        self.set_attn_scale(attn_scale)

        if pil_image is not None:
            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
        if not isinstance(prompt, list):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )
        # Duplicate each reference's tokens num_samples times: (B, T, D) -> (B * num_samples, T, D).
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_videos_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = get_generator(seed, self.device)
        video = self.pipe.video_edit(
            prompt="",
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            width=width,
            height=height,
            video_length=video_length,
            latents=video_latents,
            **kwargs,
        ).videos
        return video