Spaces:
Running
on
Zero
Running
on
Zero
zejunyang
commited on
Commit
·
e4de730
1
Parent(s):
fab87df
debug
Browse files- src/create_modules.py +40 -35
src/create_modules.py
CHANGED
@@ -33,14 +33,11 @@ from src.utils.crop_face_single import crop_face
|
|
33 |
|
34 |
class Processer():
|
35 |
def __init__(self):
|
36 |
-
self.create_models()
|
37 |
-
|
38 |
@spaces.GPU
|
39 |
def create_models(self):
|
40 |
|
41 |
-
self.lmk_extractor = LMKExtractor()
|
42 |
-
self.vis = FaceMeshVisualizer(forehead_edge=False)
|
43 |
-
|
44 |
config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
|
45 |
|
46 |
if config.weight_dtype == "fp16":
|
@@ -50,64 +47,69 @@ class Processer():
|
|
50 |
|
51 |
audio_infer_config = OmegaConf.load(config.audio_inference_config)
|
52 |
# prepare model
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
|
57 |
-
|
58 |
config.pretrained_vae_path,
|
59 |
).to("cuda", dtype=weight_dtype)
|
60 |
|
61 |
-
|
62 |
config.pretrained_base_model_path,
|
63 |
subfolder="unet",
|
64 |
).to(dtype=weight_dtype, device="cuda")
|
65 |
|
66 |
inference_config_path = config.inference_config
|
67 |
infer_config = OmegaConf.load(inference_config_path)
|
68 |
-
|
69 |
config.pretrained_base_model_path,
|
70 |
config.motion_module_path,
|
71 |
subfolder="unet",
|
72 |
unet_additional_kwargs=infer_config.unet_additional_kwargs,
|
73 |
).to(dtype=weight_dtype, device="cuda")
|
74 |
|
75 |
-
|
76 |
|
77 |
-
|
78 |
config.image_encoder_path
|
79 |
).to(dtype=weight_dtype, device="cuda")
|
80 |
|
81 |
sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
|
82 |
-
|
83 |
|
84 |
# load pretrained weights
|
85 |
-
|
86 |
torch.load(config.denoising_unet_path, map_location="cpu"),
|
87 |
strict=False,
|
88 |
)
|
89 |
-
|
90 |
torch.load(config.reference_unet_path, map_location="cpu"),
|
91 |
)
|
92 |
-
|
93 |
torch.load(config.pose_guider_path, map_location="cpu"),
|
94 |
)
|
95 |
|
96 |
-
|
97 |
-
vae=
|
98 |
-
image_encoder=
|
99 |
-
reference_unet=
|
100 |
-
denoising_unet=
|
101 |
-
pose_guider=
|
102 |
-
scheduler=
|
103 |
)
|
104 |
-
|
|
|
|
|
105 |
|
106 |
|
107 |
@spaces.GPU
|
108 |
def audio2video(self, input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
|
109 |
fps = 30
|
110 |
cfg = 3.5
|
|
|
|
|
|
|
111 |
|
112 |
config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
|
113 |
audio_infer_config = OmegaConf.load(config.audio_inference_config)
|
@@ -123,19 +125,19 @@ class Processer():
|
|
123 |
save_dir.mkdir(exist_ok=True, parents=True)
|
124 |
|
125 |
ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
|
126 |
-
ref_image_np = crop_face(ref_image_np,
|
127 |
if ref_image_np is None:
|
128 |
return None, Image.fromarray(ref_img)
|
129 |
|
130 |
ref_image_np = cv2.resize(ref_image_np, (size, size))
|
131 |
ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
|
132 |
|
133 |
-
face_result =
|
134 |
if face_result is None:
|
135 |
return None, ref_image_pil
|
136 |
|
137 |
lmks = face_result['lmks'].astype(np.float32)
|
138 |
-
ref_pose =
|
139 |
|
140 |
sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
|
141 |
sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
|
@@ -148,7 +150,7 @@ class Processer():
|
|
148 |
pred = pred + face_result['lmks3d']
|
149 |
|
150 |
if headpose_video is not None:
|
151 |
-
pose_seq = get_headpose_temp(headpose_video,
|
152 |
else:
|
153 |
pose_seq = np.load(config['pose_temp'])
|
154 |
mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
|
@@ -159,7 +161,7 @@ class Processer():
|
|
159 |
|
160 |
pose_images = []
|
161 |
for i, verts in enumerate(projected_vertices):
|
162 |
-
lmk_img =
|
163 |
pose_images.append(lmk_img)
|
164 |
|
165 |
pose_list = []
|
@@ -210,6 +212,9 @@ class Processer():
|
|
210 |
@spaces.GPU
|
211 |
def video2video(self, ref_img, source_video, size=512, steps=25, length=150, seed=42):
|
212 |
cfg = 3.5
|
|
|
|
|
|
|
213 |
|
214 |
generator = torch.manual_seed(seed)
|
215 |
width, height = size, size
|
@@ -222,19 +227,19 @@ class Processer():
|
|
222 |
save_dir.mkdir(exist_ok=True, parents=True)
|
223 |
|
224 |
ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
|
225 |
-
ref_image_np = crop_face(ref_image_np,
|
226 |
if ref_image_np is None:
|
227 |
return None, Image.fromarray(ref_img)
|
228 |
|
229 |
ref_image_np = cv2.resize(ref_image_np, (size, size))
|
230 |
ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
|
231 |
|
232 |
-
face_result =
|
233 |
if face_result is None:
|
234 |
return None, ref_image_pil
|
235 |
|
236 |
lmks = face_result['lmks'].astype(np.float32)
|
237 |
-
ref_pose =
|
238 |
|
239 |
source_images = read_frames(source_video)
|
240 |
src_fps = get_fps(source_video)
|
@@ -257,7 +262,7 @@ class Processer():
|
|
257 |
src_tensor_list.append(pose_transform(src_image_pil))
|
258 |
src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
|
259 |
frame_height, frame_width, _ = src_img_np.shape
|
260 |
-
src_img_result =
|
261 |
if src_img_result is None:
|
262 |
break
|
263 |
pose_trans_list.append(src_img_result['trans_mat'])
|
@@ -291,7 +296,7 @@ class Processer():
|
|
291 |
|
292 |
pose_list = []
|
293 |
for i, verts in enumerate(projected_vertices):
|
294 |
-
lmk_img =
|
295 |
pose_image_np = cv2.resize(lmk_img, (width, height))
|
296 |
pose_list.append(pose_image_np)
|
297 |
|
|
|
33 |
|
34 |
class Processer():
|
35 |
def __init__(self):
|
36 |
+
self.a2m_model, self.pipe = self.create_models()
|
37 |
+
|
38 |
@spaces.GPU
|
39 |
def create_models(self):
|
40 |
|
|
|
|
|
|
|
41 |
config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
|
42 |
|
43 |
if config.weight_dtype == "fp16":
|
|
|
47 |
|
48 |
audio_infer_config = OmegaConf.load(config.audio_inference_config)
|
49 |
# prepare model
|
50 |
+
a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
|
51 |
+
a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
|
52 |
+
a2m_model.to("cuda").eval()
|
53 |
|
54 |
+
vae = AutoencoderKL.from_pretrained(
|
55 |
config.pretrained_vae_path,
|
56 |
).to("cuda", dtype=weight_dtype)
|
57 |
|
58 |
+
reference_unet = UNet2DConditionModel.from_pretrained(
|
59 |
config.pretrained_base_model_path,
|
60 |
subfolder="unet",
|
61 |
).to(dtype=weight_dtype, device="cuda")
|
62 |
|
63 |
inference_config_path = config.inference_config
|
64 |
infer_config = OmegaConf.load(inference_config_path)
|
65 |
+
denoising_unet = UNet3DConditionModel.from_pretrained_2d(
|
66 |
config.pretrained_base_model_path,
|
67 |
config.motion_module_path,
|
68 |
subfolder="unet",
|
69 |
unet_additional_kwargs=infer_config.unet_additional_kwargs,
|
70 |
).to(dtype=weight_dtype, device="cuda")
|
71 |
|
72 |
+
pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
|
73 |
|
74 |
+
image_enc = CLIPVisionModelWithProjection.from_pretrained(
|
75 |
config.image_encoder_path
|
76 |
).to(dtype=weight_dtype, device="cuda")
|
77 |
|
78 |
sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
|
79 |
+
scheduler = DDIMScheduler(**sched_kwargs)
|
80 |
|
81 |
# load pretrained weights
|
82 |
+
denoising_unet.load_state_dict(
|
83 |
torch.load(config.denoising_unet_path, map_location="cpu"),
|
84 |
strict=False,
|
85 |
)
|
86 |
+
reference_unet.load_state_dict(
|
87 |
torch.load(config.reference_unet_path, map_location="cpu"),
|
88 |
)
|
89 |
+
pose_guider.load_state_dict(
|
90 |
torch.load(config.pose_guider_path, map_location="cpu"),
|
91 |
)
|
92 |
|
93 |
+
pipe = Pose2VideoPipeline(
|
94 |
+
vae=vae,
|
95 |
+
image_encoder=image_enc,
|
96 |
+
reference_unet=reference_unet,
|
97 |
+
denoising_unet=denoising_unet,
|
98 |
+
pose_guider=pose_guider,
|
99 |
+
scheduler=scheduler,
|
100 |
)
|
101 |
+
pipe = pipe.to("cuda", dtype=weight_dtype)
|
102 |
+
|
103 |
+
return a2m_model, pipe
|
104 |
|
105 |
|
106 |
@spaces.GPU
|
107 |
def audio2video(self, input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
|
108 |
fps = 30
|
109 |
cfg = 3.5
|
110 |
+
|
111 |
+
lmk_extractor = LMKExtractor()
|
112 |
+
vis = FaceMeshVisualizer()
|
113 |
|
114 |
config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
|
115 |
audio_infer_config = OmegaConf.load(config.audio_inference_config)
|
|
|
125 |
save_dir.mkdir(exist_ok=True, parents=True)
|
126 |
|
127 |
ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
|
128 |
+
ref_image_np = crop_face(ref_image_np, lmk_extractor)
|
129 |
if ref_image_np is None:
|
130 |
return None, Image.fromarray(ref_img)
|
131 |
|
132 |
ref_image_np = cv2.resize(ref_image_np, (size, size))
|
133 |
ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
|
134 |
|
135 |
+
face_result = lmk_extractor(ref_image_np)
|
136 |
if face_result is None:
|
137 |
return None, ref_image_pil
|
138 |
|
139 |
lmks = face_result['lmks'].astype(np.float32)
|
140 |
+
ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
|
141 |
|
142 |
sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
|
143 |
sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
|
|
|
150 |
pred = pred + face_result['lmks3d']
|
151 |
|
152 |
if headpose_video is not None:
|
153 |
+
pose_seq = get_headpose_temp(headpose_video, lmk_extractor)
|
154 |
else:
|
155 |
pose_seq = np.load(config['pose_temp'])
|
156 |
mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
|
|
|
161 |
|
162 |
pose_images = []
|
163 |
for i, verts in enumerate(projected_vertices):
|
164 |
+
lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
|
165 |
pose_images.append(lmk_img)
|
166 |
|
167 |
pose_list = []
|
|
|
212 |
@spaces.GPU
|
213 |
def video2video(self, ref_img, source_video, size=512, steps=25, length=150, seed=42):
|
214 |
cfg = 3.5
|
215 |
+
|
216 |
+
lmk_extractor = LMKExtractor()
|
217 |
+
vis = FaceMeshVisualizer()
|
218 |
|
219 |
generator = torch.manual_seed(seed)
|
220 |
width, height = size, size
|
|
|
227 |
save_dir.mkdir(exist_ok=True, parents=True)
|
228 |
|
229 |
ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
|
230 |
+
ref_image_np = crop_face(ref_image_np, lmk_extractor)
|
231 |
if ref_image_np is None:
|
232 |
return None, Image.fromarray(ref_img)
|
233 |
|
234 |
ref_image_np = cv2.resize(ref_image_np, (size, size))
|
235 |
ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
|
236 |
|
237 |
+
face_result = lmk_extractor(ref_image_np)
|
238 |
if face_result is None:
|
239 |
return None, ref_image_pil
|
240 |
|
241 |
lmks = face_result['lmks'].astype(np.float32)
|
242 |
+
ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
|
243 |
|
244 |
source_images = read_frames(source_video)
|
245 |
src_fps = get_fps(source_video)
|
|
|
262 |
src_tensor_list.append(pose_transform(src_image_pil))
|
263 |
src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
|
264 |
frame_height, frame_width, _ = src_img_np.shape
|
265 |
+
src_img_result = lmk_extractor(src_img_np)
|
266 |
if src_img_result is None:
|
267 |
break
|
268 |
pose_trans_list.append(src_img_result['trans_mat'])
|
|
|
296 |
|
297 |
pose_list = []
|
298 |
for i, verts in enumerate(projected_vertices):
|
299 |
+
lmk_img = vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
|
300 |
pose_image_np = cv2.resize(lmk_img, (width, height))
|
301 |
pose_list.append(pose_image_np)
|
302 |
|