File size: 6,731 Bytes
02cc20b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0b5a77
 
02cc20b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from diffusers import AutoencoderKL, DDIMScheduler
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from animatediff.models.unet import UNet3DConditionModel
from omegaconf import OmegaConf
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import load_weights
from safetensors import safe_open
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from faceadapter.face_adapter import FaceAdapterPlusForVideoLora
from adaface.adaface_wrapper import AdaFaceWrapper

def load_adaface(base_model_path, adaface_ckpt_path, device="cuda"):
    """Build an ``AdaFaceWrapper`` configured for the "text2img" pipeline.

    Args:
        base_model_path: Path of the base model handed to the wrapper. Per the
            original author's note it is only needed for initialization and is
            not really used during inference.
        adaface_ckpt_path: Path of the AdaFace checkpoint to load.
        device: Device passed through to the wrapper (defaults to "cuda").

    Returns:
        The constructed ``AdaFaceWrapper`` instance.
    """
    wrapper_kwargs = dict(
        pipeline_name="text2img",
        base_model_path=base_model_path,
        adaface_ckpt_path=adaface_ckpt_path,
        device=device,
    )
    return AdaFaceWrapper(**wrapper_kwargs)

# Ordered (legacy substring, replacement) pairs applied to VAE mid-block
# attention parameter names. Order matters: the first matching pair wins.
_VAE_ATTENTION_RENAMES = (
    ("key", "to_k"),
    ("query", "to_q"),
    ("value", "to_v"),
    ("proj_attn", "to_out.0"),
)


def _rename_vae_attention_keys(vae_state_dict):
    """Rename encoder/decoder mid-block attention entries of ``vae_state_dict`` in place."""
    # Iterate over a snapshot of the keys because the dict is mutated in the loop.
    for old_key in list(vae_state_dict.keys()):
        if ("encoder.mid_block.attentions" not in old_key
                and "decoder.mid_block.attentions" not in old_key):
            continue
        for legacy, current in _VAE_ATTENTION_RENAMES:
            if legacy in old_key:
                vae_state_dict[old_key.replace(legacy, current)] = vae_state_dict.pop(old_key)
                break
    return vae_state_dict


def _read_safetensors(path):
    """Read every tensor of the safetensors file at ``path`` onto the CPU as a dict."""
    with safe_open(path, framework="pt", device="cpu") as f:
        return {key: f.get_tensor(key) for key in f.keys()}


def load_model(base_model_type="sar", adaface_base_model_type="sar",
               adaface_ckpt_path=None, device="cuda"):
    """Assemble the animation pipeline, wrap it in the face adapter, and optionally load AdaFace.

    Args:
        base_model_type: Key selecting the checkpoint whose VAE, UNet and text
            encoder weights are loaded into the pipeline (see the table below).
        adaface_base_model_type: Key selecting the base model handed to
            ``load_adaface``. If a ``"<key>_adaface"`` entry exists in the table
            it takes precedence over ``"<key>"``.
        adaface_ckpt_path: AdaFace checkpoint path; when ``None`` no AdaFace
            wrapper is created.
        device: Device that all components are moved to (defaults to "cuda").

    Returns:
        A tuple ``(id_animator, adaface)`` where ``id_animator`` is the
        ``FaceAdapterPlusForVideoLora`` built around the pipeline and
        ``adaface`` is the AdaFace wrapper, or ``None`` when
        ``adaface_ckpt_path`` is ``None``.

    Raises:
        KeyError: If ``base_model_type`` or ``adaface_base_model_type`` is not
            a key of the model table.
    """
    inference_config_path = "inference-v2.yaml"
    sd_version            = "animatediff/sd"
    id_ckpt               = "models/animator.ckpt"
    image_encoder_path    = "models/image_encoder"
    motion_module_path    = "models/v3_sd15_mm.ckpt"
    motion_lora_path      = "models/v3_sd15_adapter.ckpt"

    base_model_type_to_path = {
        "rv40": "models/realisticvision/realisticVisionV40_v40VAE.safetensors",
        "rv60": "models/realisticvision/realisticVisionV60B1_v51VAE.safetensors",
        "sd15": "models/stable-diffusion-v-1-5/v1-5-pruned.safetensors",
        "sd15_adaface": "models/stable-diffusion-v-1-5/v1-5-dste8-vae.ckpt",
        "toonyou": "models/toonyou/toonyou_beta6.safetensors",
        "epv5": "models/epic_realism/epicrealism_pureEvolutionV5.safetensors",
        "ar181": "models/absolutereality/absolutereality_v181.safetensors",
        "ar16":  "models/absolutereality/ar-v1-6.safetensors",
        "sar":   "models/sar/sar.safetensors",
    }

    # Fail with an actionable message (still a KeyError) on unknown model types.
    if base_model_type not in base_model_type_to_path:
        raise KeyError(f"Unknown base_model_type {base_model_type!r}; "
                       f"expected one of {sorted(base_model_type_to_path)}")
    base_model_path = base_model_type_to_path[base_model_type]

    # Prefer the dedicated "<type>_adaface" checkpoint when the table has one.
    adaface_key = adaface_base_model_type + "_adaface"
    if adaface_key not in base_model_type_to_path:
        adaface_key = adaface_base_model_type
    if adaface_key not in base_model_type_to_path:
        raise KeyError(f"Unknown adaface_base_model_type {adaface_base_model_type!r}; "
                       f"expected one of {sorted(base_model_type_to_path)}")
    adaface_base_model_path = base_model_type_to_path[adaface_key]

    inference_config = OmegaConf.load(inference_config_path)

    tokenizer = CLIPTokenizer.from_pretrained(
        sd_version, subfolder="tokenizer", torch_dtype=torch.float16)
    text_encoder = CLIPTextModel.from_pretrained(
        sd_version, subfolder="text_encoder", torch_dtype=torch.float16).to(device=device)
    vae = AutoencoderKL.from_pretrained(
        sd_version, subfolder="vae", torch_dtype=torch.float16).to(device=device)
    unet = UNet3DConditionModel.from_pretrained_2d(
        sd_version, subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
    ).to(device=device)
    # unet.to(dtype=torch.float16) does not work on hf spaces.
    unet = unet.half()

    scheduler = DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs))
    pipeline = AnimationPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
        controlnet=None,
        scheduler=scheduler,
        torch_dtype=torch.float16,
    ).to(device=device)

    pipeline = load_weights(
        pipeline,
        # motion module
        motion_module_path         = motion_module_path,
        motion_module_lora_configs = [],
        # domain adapter
        adapter_lora_path          = motion_lora_path,
        adapter_lora_scale         = 1,
        # image layers
        dreambooth_model_path      = None,
        lora_model_path            = "",
        lora_alpha                 = 0.8
    ).to(device=device)

    if base_model_path != "":
        print(f"load dreambooth model from {base_model_path}")
        # All tensors are copied out, so the file is closed before conversion starts.
        dreambooth_state_dict = _read_safetensors(base_model_path)

        # VAE: convert, then rename the mid-block attention parameters
        # (key/query/value/proj_attn -> to_k/to_q/to_v/to_out.0) before loading.
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config)
        _rename_vae_attention_keys(converted_vae_checkpoint)
        pipeline.vae.load_state_dict(converted_vae_checkpoint)

        # UNet: non-strict load, so keys absent from the converted checkpoint
        # keep the values they already have in the pipeline's UNet.
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config)
        pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)

        pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict, dtype=torch.float16).to(device=device)

        del dreambooth_state_dict

    # Built unconditionally so the function always returns the documented tuple
    # instead of silently returning None when no base checkpoint is loaded.
    pipeline = pipeline.to(torch.float16)
    id_animator = FaceAdapterPlusForVideoLora(pipeline, image_encoder_path, id_ckpt, num_tokens=16,
                                              device=torch.device(device), torch_type=torch.float16)

    if adaface_ckpt_path is not None:
        adaface = load_adaface(adaface_base_model_path, adaface_ckpt_path, device)
    else:
        adaface = None

    return id_animator, adaface