waveydaveygravy committed
Commit 9531098
1 Parent(s): 094e6e0

Upload 6 files

bo_1resized.jpg ADDED
bo_1resized.mp4 ADDED
Binary file (991 kB).
 
bo_1resized_ang_bo_1resized.mp4 ADDED
Binary file (156 kB).
 
demoworking.py ADDED
@@ -0,0 +1,470 @@
+ #@title demo.py with fixed paths
+
+ import os
+ import numpy as np
+ import torch
+ import yaml
+ from modules.generator import OcclusionAwareSPADEGeneratorEam
+ from modules.keypoint_detector import KPDetector, HEEstimator
+ import argparse
+ import imageio
+ from modules.transformer import Audio2kpTransformerBBoxQDeepPrompt as Audio2kpTransformer
+ from modules.prompt import EmotionDeepPrompt, EmotionalDeformationTransformer
+ from scipy.io import wavfile
+
+ from modules.model_transformer import get_rotation_matrix, keypoint_transformation
+ from skimage import io, img_as_float32
+ from skimage.transform import resize
+ import torchaudio
+ import soundfile as sf
+ from scipy.spatial import ConvexHull
+
+ import torch.nn.functional as F
+ import glob
+ from tqdm import tqdm
+ import gzip
+
+ emo_label = ['ang', 'con', 'dis', 'fea', 'hap', 'neu', 'sad', 'sur']
+ emo_label_full = ['angry', 'contempt', 'disgusted', 'fear', 'happy', 'neutral', 'sad', 'surprised']
+ latent_dim = 16
+
+ MEL_PARAMS_25 = {
+     "n_mels": 80,
+     "n_fft": 2048,
+     "win_length": 640,
+     "hop_length": 640
+ }
+
+ to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS_25)
+ mean, std = -4, 4
+
+ expU = torch.from_numpy(np.load('/content/EAT_code/expPCAnorm_fin/U_mead.npy')[:,:32])
+ expmean = torch.from_numpy(np.load('/content/EAT_code/expPCAnorm_fin/mean_mead.npy'))
+
+ root_wav = '/content/EAT_code/demo/video_processed/bo_1resized'
+ def normalize_kp(kp_source, kp_driving, kp_driving_initial,
+                  use_relative_movement=True, use_relative_jacobian=True):
+
+     kp_new = {k: v for k, v in kp_driving.items()}
+     if use_relative_movement:
+         kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
+         kp_new['value'] = kp_value_diff + kp_source['value']
+
+         if use_relative_jacobian:
+             jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
+             kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
+
+     return kp_new
+
+ def _load_tensor(data):
+     wave_path = data
+     wave, sr = sf.read(wave_path)
+     wave_tensor = torch.from_numpy(wave).float()
+     return wave_tensor
+
+ def build_model(config, device_ids=[0]):
+     generator = OcclusionAwareSPADEGeneratorEam(**config['model_params']['generator_params'],
+                                                 **config['model_params']['common_params'])
+     if torch.cuda.is_available():
+         print('cuda is available')
+         generator.to(device_ids[0])
+
+     kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
+                              **config['model_params']['common_params'])
+
+     if torch.cuda.is_available():
+         kp_detector.to(device_ids[0])
+
+
+     audio2kptransformer = Audio2kpTransformer(**config['model_params']['audio2kp_params'], face_ea=True)
+
+     if torch.cuda.is_available():
+         audio2kptransformer.to(device_ids[0])
+
+     sidetuning = EmotionalDeformationTransformer(**config['model_params']['audio2kp_params'])
+
+     if torch.cuda.is_available():
+         sidetuning.to(device_ids[0])
+
+     emotionprompt = EmotionDeepPrompt()
+
+     if torch.cuda.is_available():
+         emotionprompt.to(device_ids[0])
+
+     return generator, kp_detector, audio2kptransformer, sidetuning, emotionprompt
+
+
+ def prepare_test_data(img_path, audio_path, opt, emotype, use_otherimg=True):
+     # sr,_ = wavfile.read(audio_path)
+
+     if use_otherimg:
+         source_latent = np.load(img_path.replace('cropped', 'latent')[:-4]+'.npy', allow_pickle=True)
+     else:
+         source_latent = np.load(img_path.replace('images', 'latent')[:-9]+'.npy', allow_pickle=True)
+     he_source = {}
+     for k in source_latent[1].keys():
+         he_source[k] = torch.from_numpy(source_latent[1][k][0]).unsqueeze(0).cuda()
+
+     # source images
+     source_img = img_as_float32(io.imread(img_path)).transpose((2, 0, 1))
+     asp = os.path.basename(audio_path)[:-4]
+
+     # latent code
+     y_trg = emo_label.index(emotype)
+     z_trg = torch.randn(latent_dim)
+
+     # driving latent
+     latent_path_driving = f'{root_wav}/latent_evp_25/{asp}.npy'
+     pose_gz = gzip.GzipFile(f'{root_wav}/poseimg/{asp}.npy.gz', 'r')
+     poseimg = np.load(pose_gz)
+     deepfeature = np.load(f'{root_wav}/deepfeature32/{asp}.npy')
+     driving_latent = np.load(latent_path_driving[:-4]+'.npy', allow_pickle=True)
+     he_driving = driving_latent[1]
+
+     # gt frame number
+     frames = glob.glob(f'{root_wav}/images_evp_25/cropped/*.jpg')
+     num_frames = len(frames)
+
+     wave_tensor = _load_tensor(audio_path)
+     if len(wave_tensor.shape) > 1:
+         wave_tensor = wave_tensor[:, 0]
+     mel_tensor = to_melspec(wave_tensor)
+     mel_tensor = (torch.log(1e-5 + mel_tensor) - mean) / std
+     name_len = min(mel_tensor.shape[1], poseimg.shape[0], deepfeature.shape[0])
+
+     audio_frames = []
+     poseimgs = []
+     deep_feature = []
+
+     pad, deep_pad = np.load('/content/EAT_code/pad.npy', allow_pickle=True)
+
+     if name_len < num_frames:
+         diff = num_frames - name_len
+         if diff > 2:
+             print(f"Attention: the frames are {diff} more than name_len, we will use name_len to replace num_frames")
+         num_frames = name_len
+         for k in he_driving.keys():
+             he_driving[k] = he_driving[k][:name_len, :]
+     for rid in range(0, num_frames):
+         audio = []
+         poses = []
+         deeps = []
+         for i in range(rid - opt['num_w'], rid + opt['num_w'] + 1):
+             if i < 0:
+                 audio.append(pad)
+                 poses.append(poseimg[0])
+                 deeps.append(deep_pad)
+             elif i >= name_len:
+                 audio.append(pad)
+                 poses.append(poseimg[-1])
+                 deeps.append(deep_pad)
+             else:
+                 audio.append(mel_tensor[:, i])
+                 poses.append(poseimg[i])
+                 deeps.append(deepfeature[i])
+
+         audio_frames.append(torch.stack(audio, dim=1))
+         poseimgs.append(poses)
+         deep_feature.append(deeps)
+     audio_frames = torch.stack(audio_frames, dim=0)
+     poseimgs = torch.from_numpy(np.array(poseimgs))
+     deep_feature = torch.from_numpy(np.array(deep_feature)).to(torch.float)
+     return audio_frames, poseimgs, deep_feature, source_img, he_source, he_driving, num_frames, y_trg, z_trg, latent_path_driving
+
+ def load_ckpt(ckpt, kp_detector, generator, audio2kptransformer, sidetuning, emotionprompt):
+     checkpoint = torch.load(ckpt, map_location=torch.device('cpu'))
+     if audio2kptransformer is not None:
+         audio2kptransformer.load_state_dict(checkpoint['audio2kptransformer'])
+     if generator is not None:
+         generator.load_state_dict(checkpoint['generator'])
+     if kp_detector is not None:
+         kp_detector.load_state_dict(checkpoint['kp_detector'])
+     if sidetuning is not None:
+         sidetuning.load_state_dict(checkpoint['sidetuning'])
+     if emotionprompt is not None:
+         emotionprompt.load_state_dict(checkpoint['emotionprompt'])
+
+ import cv2
+ import dlib
+ from tqdm import tqdm
+ from skimage import transform as tf
+ detector = dlib.get_frontal_face_detector()
+ predictor = dlib.shape_predictor('/content/EAT_code/demo/shape_predictor_68_face_landmarks.dat')
+
+ def shape_to_np(shape, dtype="int"):
+     # initialize the list of (x, y)-coordinates
+     coords = np.zeros((shape.num_parts, 2), dtype=dtype)
+
+     # loop over all facial landmarks and convert them
+     # to a 2-tuple of (x, y)-coordinates
+     for i in range(0, shape.num_parts):
+         coords[i] = (shape.part(i).x, shape.part(i).y)
+
+     # return the list of (x, y)-coordinates
+     return coords
+
+ def crop_image(image_path, out_path):
+     template = np.load('/content/EAT_code/demo/bo_1resized_template.npy')
+     image = cv2.imread(image_path)
+     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+     rects = detector(gray, 1)  # detect human face
+     if len(rects) != 1:
+         return 0
+     for (j, rect) in enumerate(rects):
+         shape = predictor(gray, rect)  # detect 68 points
+         shape = shape_to_np(shape)
+
+         pts2 = np.float32(template[:47,:])
+         pts1 = np.float32(shape[:47,:])  # eye and nose
+         tform = tf.SimilarityTransform()
+         tform.estimate(pts2, pts1)  # Set the transformation matrix with the explicit parameters.
+
+         dst = tf.warp(image, tform, output_shape=(256, 256))
+
+         dst = np.array(dst * 255, dtype=np.uint8)
+
+         cv2.imwrite(out_path, dst)
+
+ def preprocess_imgs(allimgs, tmp_allimgs_cropped):
+     name_cropped = []
+     for path in tmp_allimgs_cropped:
+         name_cropped.append(os.path.basename(path))
+     for path in allimgs:
+         if os.path.basename(path) in name_cropped:
+             continue
+         else:
+             out_path = path.replace('imgs1/', 'imgs_cropped1/')
+             crop_image(path, out_path)
+
+ from sync_batchnorm import DataParallelWithCallback
+ def load_checkpoints_extractor(config_path, checkpoint_path, cpu=False):
+
+     with open(config_path) as f:
+         config = yaml.load(f, Loader=yaml.FullLoader)
+
+     kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
+                              **config['model_params']['common_params'])
+     if not cpu:
+         kp_detector.cuda()
+
+     he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
+                                **config['model_params']['common_params'])
+     if not cpu:
+         he_estimator.cuda()
+
+     if cpu:
+         checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+     else:
+         checkpoint = torch.load(checkpoint_path)
+
+     kp_detector.load_state_dict(checkpoint['kp_detector'])
+     he_estimator.load_state_dict(checkpoint['he_estimator'])
+
+     if not cpu:
+         kp_detector = DataParallelWithCallback(kp_detector)
+         he_estimator = DataParallelWithCallback(he_estimator)
+
+     kp_detector.eval()
+     he_estimator.eval()
+
+     return kp_detector, he_estimator
+
+ def estimate_latent(driving_video, kp_detector, he_estimator):
+     with torch.no_grad():
+         predictions = []
+         driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).cuda()
+         kp_canonical = kp_detector(driving[:, :, 0])
+         he_drivings = {'yaw': [], 'pitch': [], 'roll': [], 't': [], 'exp': []}
+
+         for frame_idx in range(driving.shape[2]):
+             driving_frame = driving[:, :, frame_idx]
+             he_driving = he_estimator(driving_frame)
+             for k in he_drivings.keys():
+                 he_drivings[k].append(he_driving[k])
+     return [kp_canonical, he_drivings]
+
+ def extract_keypoints(extract_list):
+     kp_detector, he_estimator = load_checkpoints_extractor(config_path='/content/EAT_code/config/vox-256-spade.yaml', checkpoint_path='/content/EAT_code/ckpt/pretrain_new_274.pth.tar')
+     if not os.path.exists('./demo/imgs_latent/'):
+         os.makedirs('./demo/imgs_latent/')
+     for imgname in tqdm(extract_list):
+         path_frames = [imgname]
+         filesname = os.path.basename(imgname)[:-4]
+         if os.path.exists(f'./demo/imgs_latent/'+filesname+'.npy'):
+             continue
+         driving_frames = []
+         for im in path_frames:
+             driving_frames.append(imageio.imread(im))
+         driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_frames]
+
+         kc, he = estimate_latent(driving_video, kp_detector, he_estimator)
+         kc = kc['value'].cpu().numpy()
+         for k in he:
+             he[k] = torch.cat(he[k]).cpu().numpy()
+         np.save('./demo/imgs_latent/'+filesname, [kc, he])
+
+ def preprocess_cropped_imgs(allimgs_cropped):
+     extract_list = []
+     for img_path in allimgs_cropped:
+         if not os.path.exists(img_path.replace('cropped', 'latent')[:-4]+'.npy'):
+             extract_list.append(img_path)
+     if len(extract_list) > 0:
+         print('=========', "Extract latent keypoints from New image", '======')
+         extract_keypoints(extract_list)
+
+ def test(ckpt, emotype, save_dir=" "):
+     # with open("config/vox-transformer2.yaml") as f:
+     with open("/content/EAT_code/config/deepprompt_eam3d_st_tanh_304_3090_all.yaml") as f:
+         config = yaml.load(f, Loader=yaml.FullLoader)
+     cur_path = os.getcwd()
+     generator, kp_detector, audio2kptransformer, sidetuning, emotionprompt = build_model(config)
+     load_ckpt(ckpt, kp_detector=kp_detector, generator=generator, audio2kptransformer=audio2kptransformer, sidetuning=sidetuning, emotionprompt=emotionprompt)
+
+     audio2kptransformer.eval()
+     generator.eval()
+     kp_detector.eval()
+     sidetuning.eval()
+     emotionprompt.eval()
+
+     all_wavs2 = [f'{root_wav}/{os.path.basename(root_wav)}.wav']
+     allimg = glob.glob('/content/EAT_code/demo/imgs1/*.jpg')
+     tmp_allimg_cropped = glob.glob('/content/EAT_code/demo/imgs_cropped1/*.jpg')
+     preprocess_imgs(allimg, tmp_allimg_cropped)  # crop and align images
+
+     allimg_cropped = glob.glob('/content/EAT_code/demo/imgs_cropped1/*.jpg')
+     preprocess_cropped_imgs(allimg_cropped)  # extract latent keypoints if necessary
+
+     for ind in tqdm(range(len(all_wavs2))):
+         for img_path in tqdm(allimg_cropped):
+             audio_path = all_wavs2[ind]
+             # read in data
+             audio_frames, poseimgs, deep_feature, source_img, he_source, he_driving, num_frames, y_trg, z_trg, latent_path_driving = prepare_test_data(img_path, audio_path, config['model_params']['audio2kp_params'], emotype)
+
+
+             with torch.no_grad():
+                 source_img = torch.from_numpy(source_img).unsqueeze(0).cuda()
+                 kp_canonical = kp_detector(source_img, with_feature=True)  # {'value': value, 'jacobian': jacobian}
+                 kp_cano = kp_canonical['value']
+
+                 x = {}
+                 x['mel'] = audio_frames.unsqueeze(1).unsqueeze(0).cuda()
+                 x['z_trg'] = z_trg.unsqueeze(0).cuda()
+                 x['y_trg'] = torch.tensor(y_trg, dtype=torch.long).cuda().reshape(1)
+                 x['pose'] = poseimgs.cuda()
+                 x['deep'] = deep_feature.cuda().unsqueeze(0)
+                 x['he_driving'] = {'yaw': torch.from_numpy(he_driving['yaw']).cuda().unsqueeze(0),
+                                    'pitch': torch.from_numpy(he_driving['pitch']).cuda().unsqueeze(0),
+                                    'roll': torch.from_numpy(he_driving['roll']).cuda().unsqueeze(0),
+                                    't': torch.from_numpy(he_driving['t']).cuda().unsqueeze(0),
+                                    }
+
+                 ### emotion prompt
+                 emoprompt, deepprompt = emotionprompt(x)
+                 a2kp_exps = []
+                 emo_exps = []
+                 T = 5
+                 if T == 1:
+                     for i in range(x['mel'].shape[1]):
+                         xi = {}
+                         xi['mel'] = x['mel'][:,i,:,:,:].unsqueeze(1)
+                         xi['z_trg'] = x['z_trg']
+                         xi['y_trg'] = x['y_trg']
+                         xi['pose'] = x['pose'][i,:,:,:,:].unsqueeze(0)
+                         xi['deep'] = x['deep'][:,i,:,:,:].unsqueeze(1)
+                         xi['he_driving'] = {'yaw': x['he_driving']['yaw'][:,i,:].unsqueeze(0),
+                                             'pitch': x['he_driving']['pitch'][:,i,:].unsqueeze(0),
+                                             'roll': x['he_driving']['roll'][:,i,:].unsqueeze(0),
+                                             't': x['he_driving']['t'][:,i,:].unsqueeze(0),
+                                             }
+                         he_driving_emo_xi, input_st_xi = audio2kptransformer(xi, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True)  # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
+                         emo_exp = sidetuning(input_st_xi, emoprompt, deepprompt)
+                         a2kp_exps.append(he_driving_emo_xi['emo'])
+                         emo_exps.append(emo_exp)
+                 elif T is not None:
+                     for i in range(x['mel'].shape[1]//T+1):
+                         if i*T >= x['mel'].shape[1]:
+                             break
+                         xi = {}
+                         xi['mel'] = x['mel'][:,i*T:(i+1)*T,:,:,:]
+                         xi['z_trg'] = x['z_trg']
+                         xi['y_trg'] = x['y_trg']
+                         xi['pose'] = x['pose'][i*T:(i+1)*T,:,:,:,:]
+                         xi['deep'] = x['deep'][:,i*T:(i+1)*T,:,:,:]
+                         xi['he_driving'] = {'yaw': x['he_driving']['yaw'][:,i*T:(i+1)*T,:],
+                                             'pitch': x['he_driving']['pitch'][:,i*T:(i+1)*T,:],
+                                             'roll': x['he_driving']['roll'][:,i*T:(i+1)*T,:],
+                                             't': x['he_driving']['t'][:,i*T:(i+1)*T,:],
+                                             }
+                         he_driving_emo_xi, input_st_xi = audio2kptransformer(xi, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True)  # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
+                         emo_exp = sidetuning(input_st_xi, emoprompt, deepprompt)
+                         a2kp_exps.append(he_driving_emo_xi['emo'])
+                         emo_exps.append(emo_exp)
+
+                 if T is None:
+                     he_driving_emo, input_st = audio2kptransformer(x, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True)  # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
+                     emo_exps = sidetuning(input_st, emoprompt, deepprompt).reshape(-1, 45)
+                 else:
+                     he_driving_emo = {}
+                     he_driving_emo['emo'] = torch.cat(a2kp_exps, dim=0)
+                     emo_exps = torch.cat(emo_exps, dim=0).reshape(-1, 45)
+
+                 exp = he_driving_emo['emo']
+                 device = exp.get_device()
+                 exp = torch.mm(exp, expU.t().to(device))
+                 exp = exp + expmean.expand_as(exp).to(device)
+                 exp = exp + emo_exps
+
+
+                 source_area = ConvexHull(kp_cano[0].cpu().numpy()).volume
+                 exp = exp * source_area
+
+                 he_new_driving = {'yaw': torch.from_numpy(he_driving['yaw']).cuda(),
+                                   'pitch': torch.from_numpy(he_driving['pitch']).cuda(),
+                                   'roll': torch.from_numpy(he_driving['roll']).cuda(),
+                                   't': torch.from_numpy(he_driving['t']).cuda(),
+                                   'exp': exp}
+                 he_driving['exp'] = torch.from_numpy(he_driving['exp']).cuda()
+
+                 kp_source = keypoint_transformation(kp_canonical, he_source, False)
+                 mean_source = torch.mean(kp_source['value'], dim=1)[0]
+                 kp_driving = keypoint_transformation(kp_canonical, he_new_driving, False)
+                 mean_driving = torch.mean(torch.mean(kp_driving['value'], dim=1), dim=0)
+                 kp_driving['value'] = kp_driving['value']+(mean_source-mean_driving).unsqueeze(0).unsqueeze(0)
+                 bs = kp_source['value'].shape[0]
+                 predictions_gen = []
+                 for i in tqdm(range(num_frames)):
+                     kp_si = {}
+                     kp_si['value'] = kp_source['value'][0].unsqueeze(0)
+                     kp_di = {}
+                     kp_di['value'] = kp_driving['value'][i].unsqueeze(0)
+                     generated = generator(source_img, kp_source=kp_si, kp_driving=kp_di, prompt=emoprompt)
+                     predictions_gen.append(
+                         (np.transpose(generated['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0] * 255).astype(np.uint8))
+
+             log_dir = save_dir
+             os.makedirs(os.path.join(log_dir, "temp"), exist_ok=True)
+
+             f_name = os.path.basename(img_path[:-4]) + "_" + emotype + "_" + os.path.basename(latent_path_driving)[:-4] + ".mp4"
+             video_path = os.path.join(log_dir, "temp", f_name)
+             imageio.mimsave(video_path, predictions_gen, fps=25.0)
+
+             save_video = os.path.join(log_dir, f_name)
+             cmd = r'ffmpeg -loglevel error -y -i "%s" -i "%s" -vcodec copy -shortest "%s"' % (video_path, audio_path, save_video)
+             os.system(cmd)
+             os.remove(video_path)
+
+ if __name__ == '__main__':
+     argparser = argparse.ArgumentParser()
+     argparser.add_argument("--save_dir", type=str, default="/content/EAT_code/Results ", help="path of the output directory")
+     argparser.add_argument("--name", type=str, default="deepprompt_eam3d_all_final_313", help="experiment/checkpoint name, used as the output sub-folder")
+     argparser.add_argument("--emo", type=str, default="hap", help="emotion type ('ang', 'con', 'dis', 'fea', 'hap', 'neu', 'sad', 'sur')")
+     argparser.add_argument("--root_wav", type=str, default='./demo/video_processed/M003_neu_1_001', help="path of the processed driving folder (contains the .wav, pose, deep-feature and latent files)")
+     args = argparser.parse_args()
+
+     root_wav = args.root_wav
+
+     if len(args.name) > 1:
+         name = args.name
+         print(name)
+         test(f'/content/EAT_code/ckpt/deepprompt_eam3d_all_final_313.pth.tar', args.emo, save_dir=f'./demo/output/{name}/')
+
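For reference, a minimal invocation sketch for the script above (not part of the upload): the flags mirror the argparse block in demoworking.py, while the working directory, the /content/EAT_code layout, and the bo_1resized asset folder are assumptions taken from the paths hard-coded in the file; the subprocess wrapper itself is hypothetical.

import subprocess

# Hypothetical usage example: flags correspond to the argparse definitions in
# demoworking.py; paths assume the Colab-style /content/EAT_code layout the
# script hard-codes.
subprocess.run(
    [
        "python", "demoworking.py",
        "--root_wav", "/content/EAT_code/demo/video_processed/bo_1resized",
        "--emo", "ang",
        "--name", "deepprompt_eam3d_all_final_313",
    ],
    cwd="/content/EAT_code",
    check=True,
)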
obama3_hap_M003_neu_1_001.mp4 ADDED
Binary file (80.9 kB).
 
scarlett_ang_bo_1resized.mp4 ADDED
Binary file (182 kB).