# adaface-animate/infer.py
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from omegaconf import OmegaConf
from safetensors import safe_open
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.convert_from_ckpt import (
    convert_ldm_clip_checkpoint,
    convert_ldm_unet_checkpoint,
    convert_ldm_vae_checkpoint,
)
from animatediff.utils.util import load_weights
from faceadapter.face_adapter import FaceAdapterPlusForVideoLora
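
# Base model checkpoints selectable via load_model(base_model_type=...).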
base_model_type_to_path = {
    "sd15": "models/sd15-dste8-vae.safetensors",  # LDM format. Needs to be converted.
    "sar":  "models/sar/sar.safetensors",         # LDM format. Needs to be converted.
    "rv51": "models/rv51/realisticVisionV51_v51VAE.safetensors",
}


def load_model(base_model_type="rv51", device="cuda"):
    inference_config_path = "inference-v2.yaml"
    sd_version = "animatediff/sd"
    id_ckpt = "models/animator.ckpt"
    image_encoder_path = "models/image_encoder"
    base_model_path = base_model_type_to_path[base_model_type]
    motion_module_path = "models/v3_sd15_mm.ckpt"
    motion_lora_path = "models/v3_sd15_adapter.ckpt"

    inference_config = OmegaConf.load(inference_config_path)
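
    # Load the tokenizer, text encoder, and VAE of the SD 1.5 base in fp16,
    # then inflate the 2D SD UNet into a 3D AnimateDiff UNet using the extra
    # kwargs from inference-v2.yaml.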
    tokenizer = CLIPTokenizer.from_pretrained(
        sd_version, subfolder="tokenizer", torch_dtype=torch.float16,
    )
    text_encoder = CLIPTextModel.from_pretrained(
        sd_version, subfolder="text_encoder", torch_dtype=torch.float16,
    ).to(device=device)
    vae = AutoencoderKL.from_pretrained(
        sd_version, subfolder="vae", torch_dtype=torch.float16,
    ).to(device=device)
    unet = UNet3DConditionModel.from_pretrained_2d(
        sd_version, subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
    ).to(device=device)
    # unet.to(dtype=torch.float16) does not work on HF Spaces; cast with .half() instead.
    unet = unet.half()
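
    # Assemble the AnimateDiff pipeline with DDIM as the active scheduler;
    # alternative schedulers are left commented out below.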
    pipeline = AnimationPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
        controlnet=None,
        # beta_start=0.00085, beta_end=0.012, beta_schedule="linear", steps_offset=1
        scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
        # scheduler=DPMSolverMultistepScheduler(**OmegaConf.to_container(inference_config.DPMSolver_scheduler_kwargs)),
        # scheduler=EulerAncestralDiscreteScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
        # scheduler=EulerAncestralDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="linear", steps_offset=1),
        torch_dtype=torch.float16,
    ).to(device=device)
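
    # Load the AnimateDiff v3 motion module and its domain-adapter LoRA
    # (v3_sd15_adapter) into the pipeline.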
    pipeline = load_weights(
        pipeline,
        # motion module
        motion_module_path=motion_module_path,
        motion_module_lora_configs=[],
        # domain adapter
        adapter_lora_path=motion_lora_path,
        adapter_lora_scale=1,
        # image layers
        dreambooth_model_path=None,
        lora_model_path="",
        lora_alpha=0.8,
    ).to(device=device)
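
    # Replace the stock SD components with the selected base model. The
    # checkpoint is read tensor-by-tensor, and its VAE, UNet, and text encoder
    # weights are converted from the LDM layout to the diffusers layout.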
    if base_model_path != "":
        print(f"load dreambooth model from {base_model_path}")
        dreambooth_state_dict = {}
        with safe_open(base_model_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                dreambooth_state_dict[key] = f.get_tensor(key)

        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config)
        # Newer diffusers VAEs name the mid-block attention parameters
        # to_q/to_k/to_v/to_out.0 instead of query/key/value/proj_attn,
        # so remap any old-style keys the converter produced.
        convert_vae_keys = list(converted_vae_checkpoint.keys())
        for key in convert_vae_keys:
            if "encoder.mid_block.attentions" in key or "decoder.mid_block.attentions" in key:
                new_key = None
                if "key" in key:
                    new_key = key.replace("key", "to_k")
                elif "query" in key:
                    new_key = key.replace("query", "to_q")
                elif "value" in key:
                    new_key = key.replace("value", "to_v")
                elif "proj_attn" in key:
                    new_key = key.replace("proj_attn", "to_out.0")
                if new_key:
                    converted_vae_checkpoint[new_key] = converted_vae_checkpoint.pop(key)
        pipeline.vae.load_state_dict(converted_vae_checkpoint)

        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config)
        pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)

        pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict, dtype=torch.float16).to(device=device)
        del dreambooth_state_dict

    pipeline = pipeline.to(torch.float16)
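
    # Wrap the pipeline with the face-identity adapter (16 identity tokens),
    # which turns it into the final video animator.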
    id_animator = FaceAdapterPlusForVideoLora(
        pipeline, image_encoder_path, id_ckpt, num_tokens=16,
        device=torch.device(device), torch_type=torch.float16,
    )
    return id_animator
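

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original file):
    # build the animator on the default Realistic Vision 5.1 base. This only
    # works if the checkpoints listed above exist at the expected paths.
    animator = load_model(base_model_type="rv51", device="cuda")
    print(f"Loaded animator: {type(animator).__name__}")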