zejunyang
committed on
Commit
•
c7a4aba
1
Parent(s):
2de857a
init
Browse files- src/create_modules.py +96 -0
src/create_modules.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import ffmpeg
|
3 |
+
from datetime import datetime
|
4 |
+
from pathlib import Path
|
5 |
+
import numpy as np
|
6 |
+
import cv2
|
7 |
+
import torch
|
8 |
+
from scipy.spatial.transform import Rotation as R
|
9 |
+
from scipy.interpolate import interp1d
|
10 |
+
|
11 |
+
from diffusers import AutoencoderKL, DDIMScheduler
|
12 |
+
from einops import repeat
|
13 |
+
from omegaconf import OmegaConf
|
14 |
+
from PIL import Image
|
15 |
+
from torchvision import transforms
|
16 |
+
from transformers import CLIPVisionModelWithProjection
|
17 |
+
|
18 |
+
|
19 |
+
from src.models.pose_guider import PoseGuider
|
20 |
+
from src.models.unet_2d_condition import UNet2DConditionModel
|
21 |
+
from src.models.unet_3d import UNet3DConditionModel
|
22 |
+
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
|
23 |
+
from src.utils.util import save_videos_grid
|
24 |
+
|
25 |
+
from src.audio_models.model import Audio2MeshModel
|
26 |
+
from src.utils.audio_util import prepare_audio_feature
|
27 |
+
from src.utils.mp_utils import LMKExtractor
|
28 |
+
from src.utils.draw_util import FaceMeshVisualizer
|
29 |
+
from src.utils.pose_util import project_points
|
30 |
+
|
31 |
+
|
32 |
+
# Module-level helper instances, created once when this module is imported.
lmk_extractor = LMKExtractor()
vis = FaceMeshVisualizer(forehead_edge=False)

# Top-level configuration; the path is relative to the current working directory.
config = OmegaConf.load('./configs/prompts/animation_audio.yaml')

# "fp16" selects half precision; any other configured value falls back to float32.
weight_dtype = torch.float16 if config.weight_dtype == "fp16" else torch.float32
|
41 |
+
|
42 |
+
# Audio-specific settings live in a separate file referenced by the main config.
audio_infer_config = OmegaConf.load(config.audio_inference_config)

# prepare model
a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
# Deserialize the checkpoint onto the CPU first (as the other checkpoint loads
# in this file do) so that loading does not depend on the CUDA device the
# checkpoint was saved from; the model is moved to the GPU right after.
# strict=False tolerates missing/unexpected keys in the checkpoint.
a2m_model.load_state_dict(
    torch.load(
        audio_infer_config['pretrained_model']['a2m_ckpt'],
        map_location="cpu",
    ),
    strict=False,
)
a2m_model.cuda().eval()
|
47 |
+
|
48 |
+
# Instantiate every network on the GPU in the configured precision.
vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path).to(
    device="cuda", dtype=weight_dtype
)

reference_unet = UNet2DConditionModel.from_pretrained(
    config.pretrained_base_model_path, subfolder="unet"
).to(device="cuda", dtype=weight_dtype)

# The inference config supplies the extra UNet kwargs here and the noise
# scheduler kwargs further down.
inference_config_path = config.inference_config
infer_config = OmegaConf.load(inference_config_path)
denoising_unet = UNet3DConditionModel.from_pretrained_2d(
    config.pretrained_base_model_path,
    config.motion_module_path,
    subfolder="unet",
    unet_additional_kwargs=infer_config.unet_additional_kwargs,
).to(device="cuda", dtype=weight_dtype)

# NOTE(review): the original trailing comment here read "not use cross
# attention", yet use_ca=True is passed -- confirm which one is intended.
pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(
    device="cuda", dtype=weight_dtype
)

image_enc = CLIPVisionModelWithProjection.from_pretrained(
    config.image_encoder_path
).to(device="cuda", dtype=weight_dtype)

# Convert the OmegaConf node to a plain container before unpacking it as kwargs.
sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
scheduler = DDIMScheduler(**sched_kwargs)
|
75 |
+
|
76 |
+
# load pretrained weights
# Each entry is (module, checkpoint path, strict). Only the denoising UNet is
# loaded non-strictly; the other two use load_state_dict's default strict=True,
# spelled out here. Checkpoints are deserialized onto the CPU.
for _module, _ckpt_path, _strict in (
    (denoising_unet, config.denoising_unet_path, False),
    (reference_unet, config.reference_unet_path, True),
    (pose_guider, config.pose_guider_path, True),
):
    _module.load_state_dict(
        torch.load(_ckpt_path, map_location="cpu"),
        strict=_strict,
    )
# Keep the loop temporaries out of the module namespace.
del _module, _ckpt_path, _strict
|
87 |
+
|
88 |
+
# Assemble the pipeline from the components built above, then move it to the
# GPU in the configured precision; `pipe` is bound to the result of `.to(...)`.
pipe = Pose2VideoPipeline(
    vae=vae,
    image_encoder=image_enc,
    reference_unet=reference_unet,
    denoising_unet=denoising_unet,
    pose_guider=pose_guider,
    scheduler=scheduler,
).to("cuda", dtype=weight_dtype)
|