zejunyang committed on
Commit
e4de730
·
1 Parent(s): fab87df
Files changed (1) hide show
  1. src/create_modules.py +40 -35
src/create_modules.py CHANGED
@@ -33,14 +33,11 @@ from src.utils.crop_face_single import crop_face
33
 
34
  class Processer():
35
  def __init__(self):
36
- self.create_models()
37
-
38
  @spaces.GPU
39
  def create_models(self):
40
 
41
- self.lmk_extractor = LMKExtractor()
42
- self.vis = FaceMeshVisualizer(forehead_edge=False)
43
-
44
  config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
45
 
46
  if config.weight_dtype == "fp16":
@@ -50,64 +47,69 @@ class Processer():
50
 
51
  audio_infer_config = OmegaConf.load(config.audio_inference_config)
52
  # prepare model
53
- self.a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
54
- self.a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
55
- self.a2m_model.to("cuda").eval()
56
 
57
- self.vae = AutoencoderKL.from_pretrained(
58
  config.pretrained_vae_path,
59
  ).to("cuda", dtype=weight_dtype)
60
 
61
- self.reference_unet = UNet2DConditionModel.from_pretrained(
62
  config.pretrained_base_model_path,
63
  subfolder="unet",
64
  ).to(dtype=weight_dtype, device="cuda")
65
 
66
  inference_config_path = config.inference_config
67
  infer_config = OmegaConf.load(inference_config_path)
68
- self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
69
  config.pretrained_base_model_path,
70
  config.motion_module_path,
71
  subfolder="unet",
72
  unet_additional_kwargs=infer_config.unet_additional_kwargs,
73
  ).to(dtype=weight_dtype, device="cuda")
74
 
75
- self.pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
76
 
77
- self.image_enc = CLIPVisionModelWithProjection.from_pretrained(
78
  config.image_encoder_path
79
  ).to(dtype=weight_dtype, device="cuda")
80
 
81
  sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
82
- self.scheduler = DDIMScheduler(**sched_kwargs)
83
 
84
  # load pretrained weights
85
- self.denoising_unet.load_state_dict(
86
  torch.load(config.denoising_unet_path, map_location="cpu"),
87
  strict=False,
88
  )
89
- self.reference_unet.load_state_dict(
90
  torch.load(config.reference_unet_path, map_location="cpu"),
91
  )
92
- self.pose_guider.load_state_dict(
93
  torch.load(config.pose_guider_path, map_location="cpu"),
94
  )
95
 
96
- self.pipe = Pose2VideoPipeline(
97
- vae=self.vae,
98
- image_encoder=self.image_enc,
99
- reference_unet=self.reference_unet,
100
- denoising_unet=self.denoising_unet,
101
- pose_guider=self.pose_guider,
102
- scheduler=self.scheduler,
103
  )
104
- self.pipe = self.pipe.to("cuda", dtype=weight_dtype)
 
 
105
 
106
 
107
  @spaces.GPU
108
  def audio2video(self, input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
109
  fps = 30
110
  cfg = 3.5
 
 
 
111
 
112
  config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
113
  audio_infer_config = OmegaConf.load(config.audio_inference_config)
@@ -123,19 +125,19 @@ class Processer():
123
  save_dir.mkdir(exist_ok=True, parents=True)
124
 
125
  ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
126
- ref_image_np = crop_face(ref_image_np, self.lmk_extractor)
127
  if ref_image_np is None:
128
  return None, Image.fromarray(ref_img)
129
 
130
  ref_image_np = cv2.resize(ref_image_np, (size, size))
131
  ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
132
 
133
- face_result = self.lmk_extractor(ref_image_np)
134
  if face_result is None:
135
  return None, ref_image_pil
136
 
137
  lmks = face_result['lmks'].astype(np.float32)
138
- ref_pose = self.vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
139
 
140
  sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
141
  sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
@@ -148,7 +150,7 @@ class Processer():
148
  pred = pred + face_result['lmks3d']
149
 
150
  if headpose_video is not None:
151
- pose_seq = get_headpose_temp(headpose_video, self.lmk_extractor)
152
  else:
153
  pose_seq = np.load(config['pose_temp'])
154
  mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
@@ -159,7 +161,7 @@ class Processer():
159
 
160
  pose_images = []
161
  for i, verts in enumerate(projected_vertices):
162
- lmk_img = self.vis.draw_landmarks((width, height), verts, normed=False)
163
  pose_images.append(lmk_img)
164
 
165
  pose_list = []
@@ -210,6 +212,9 @@ class Processer():
210
  @spaces.GPU
211
  def video2video(self, ref_img, source_video, size=512, steps=25, length=150, seed=42):
212
  cfg = 3.5
 
 
 
213
 
214
  generator = torch.manual_seed(seed)
215
  width, height = size, size
@@ -222,19 +227,19 @@ class Processer():
222
  save_dir.mkdir(exist_ok=True, parents=True)
223
 
224
  ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
225
- ref_image_np = crop_face(ref_image_np, self.lmk_extractor)
226
  if ref_image_np is None:
227
  return None, Image.fromarray(ref_img)
228
 
229
  ref_image_np = cv2.resize(ref_image_np, (size, size))
230
  ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
231
 
232
- face_result = self.lmk_extractor(ref_image_np)
233
  if face_result is None:
234
  return None, ref_image_pil
235
 
236
  lmks = face_result['lmks'].astype(np.float32)
237
- ref_pose = self.vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
238
 
239
  source_images = read_frames(source_video)
240
  src_fps = get_fps(source_video)
@@ -257,7 +262,7 @@ class Processer():
257
  src_tensor_list.append(pose_transform(src_image_pil))
258
  src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
259
  frame_height, frame_width, _ = src_img_np.shape
260
- src_img_result = self.lmk_extractor(src_img_np)
261
  if src_img_result is None:
262
  break
263
  pose_trans_list.append(src_img_result['trans_mat'])
@@ -291,7 +296,7 @@ class Processer():
291
 
292
  pose_list = []
293
  for i, verts in enumerate(projected_vertices):
294
- lmk_img = self.vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
295
  pose_image_np = cv2.resize(lmk_img, (width, height))
296
  pose_list.append(pose_image_np)
297
 
 
33
 
34
  class Processer():
35
  def __init__(self):
36
+ self.a2m_model, self.pipe = self.create_models()
37
+
38
  @spaces.GPU
39
  def create_models(self):
40
 
 
 
 
41
  config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
42
 
43
  if config.weight_dtype == "fp16":
 
47
 
48
  audio_infer_config = OmegaConf.load(config.audio_inference_config)
49
  # prepare model
50
+ a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
51
+ a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
52
+ a2m_model.to("cuda").eval()
53
 
54
+ vae = AutoencoderKL.from_pretrained(
55
  config.pretrained_vae_path,
56
  ).to("cuda", dtype=weight_dtype)
57
 
58
+ reference_unet = UNet2DConditionModel.from_pretrained(
59
  config.pretrained_base_model_path,
60
  subfolder="unet",
61
  ).to(dtype=weight_dtype, device="cuda")
62
 
63
  inference_config_path = config.inference_config
64
  infer_config = OmegaConf.load(inference_config_path)
65
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
66
  config.pretrained_base_model_path,
67
  config.motion_module_path,
68
  subfolder="unet",
69
  unet_additional_kwargs=infer_config.unet_additional_kwargs,
70
  ).to(dtype=weight_dtype, device="cuda")
71
 
72
+ pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
73
 
74
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
75
  config.image_encoder_path
76
  ).to(dtype=weight_dtype, device="cuda")
77
 
78
  sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
79
+ scheduler = DDIMScheduler(**sched_kwargs)
80
 
81
  # load pretrained weights
82
+ denoising_unet.load_state_dict(
83
  torch.load(config.denoising_unet_path, map_location="cpu"),
84
  strict=False,
85
  )
86
+ reference_unet.load_state_dict(
87
  torch.load(config.reference_unet_path, map_location="cpu"),
88
  )
89
+ pose_guider.load_state_dict(
90
  torch.load(config.pose_guider_path, map_location="cpu"),
91
  )
92
 
93
+ pipe = Pose2VideoPipeline(
94
+ vae=vae,
95
+ image_encoder=image_enc,
96
+ reference_unet=reference_unet,
97
+ denoising_unet=denoising_unet,
98
+ pose_guider=pose_guider,
99
+ scheduler=scheduler,
100
  )
101
+ pipe = pipe.to("cuda", dtype=weight_dtype)
102
+
103
+ return a2m_model, pipe
104
 
105
 
106
  @spaces.GPU
107
  def audio2video(self, input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
108
  fps = 30
109
  cfg = 3.5
110
+
111
+ lmk_extractor = LMKExtractor()
112
+ vis = FaceMeshVisualizer()
113
 
114
  config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
115
  audio_infer_config = OmegaConf.load(config.audio_inference_config)
 
125
  save_dir.mkdir(exist_ok=True, parents=True)
126
 
127
  ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
128
+ ref_image_np = crop_face(ref_image_np, lmk_extractor)
129
  if ref_image_np is None:
130
  return None, Image.fromarray(ref_img)
131
 
132
  ref_image_np = cv2.resize(ref_image_np, (size, size))
133
  ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
134
 
135
+ face_result = lmk_extractor(ref_image_np)
136
  if face_result is None:
137
  return None, ref_image_pil
138
 
139
  lmks = face_result['lmks'].astype(np.float32)
140
+ ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
141
 
142
  sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
143
  sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
 
150
  pred = pred + face_result['lmks3d']
151
 
152
  if headpose_video is not None:
153
+ pose_seq = get_headpose_temp(headpose_video, lmk_extractor)
154
  else:
155
  pose_seq = np.load(config['pose_temp'])
156
  mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
 
161
 
162
  pose_images = []
163
  for i, verts in enumerate(projected_vertices):
164
+ lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
165
  pose_images.append(lmk_img)
166
 
167
  pose_list = []
 
212
  @spaces.GPU
213
  def video2video(self, ref_img, source_video, size=512, steps=25, length=150, seed=42):
214
  cfg = 3.5
215
+
216
+ lmk_extractor = LMKExtractor()
217
+ vis = FaceMeshVisualizer()
218
 
219
  generator = torch.manual_seed(seed)
220
  width, height = size, size
 
227
  save_dir.mkdir(exist_ok=True, parents=True)
228
 
229
  ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
230
+ ref_image_np = crop_face(ref_image_np, lmk_extractor)
231
  if ref_image_np is None:
232
  return None, Image.fromarray(ref_img)
233
 
234
  ref_image_np = cv2.resize(ref_image_np, (size, size))
235
  ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
236
 
237
+ face_result = lmk_extractor(ref_image_np)
238
  if face_result is None:
239
  return None, ref_image_pil
240
 
241
  lmks = face_result['lmks'].astype(np.float32)
242
+ ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
243
 
244
  source_images = read_frames(source_video)
245
  src_fps = get_fps(source_video)
 
262
  src_tensor_list.append(pose_transform(src_image_pil))
263
  src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
264
  frame_height, frame_width, _ = src_img_np.shape
265
+ src_img_result = lmk_extractor(src_img_np)
266
  if src_img_result is None:
267
  break
268
  pose_trans_list.append(src_img_result['trans_mat'])
 
296
 
297
  pose_list = []
298
  for i, verts in enumerate(projected_vertices):
299
+ lmk_img = vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
300
  pose_image_np = cv2.resize(lmk_img, (width, height))
301
  pose_list.append(pose_image_np)
302