capstonedubtrack commited on
Commit
e382b05
1 Parent(s): 8011cec

Upload 3 files

Browse files
Wav2Lip/checkpoints/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Place all your checkpoints (.pth files) here.
Wav2Lip/inference.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import listdir, path
2
+ import numpy as np
3
+ import scipy, cv2, os, sys, argparse, audio
4
+ import json, subprocess, random, string
5
+ from tqdm import tqdm
6
+ from glob import glob
7
+ import torch, face_detection
8
+ from models import Wav2Lip
9
+ import platform
10
+
11
+ parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
12
+
13
+ parser.add_argument('--checkpoint_path', type=str,
14
+ help='Name of saved checkpoint to load weights from', required=True)
15
+
16
+ parser.add_argument('--face', type=str,
17
+ help='Filepath of video/image that contains faces to use', required=True)
18
+ parser.add_argument('--audio', type=str,
19
+ help='Filepath of video/audio file to use as raw audio source', required=True)
20
+ parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
21
+ default='results/result_voice.mp4')
22
+
23
+ parser.add_argument('--static', type=bool,
24
+ help='If True, then use only first video frame for inference', default=False)
25
+ parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
26
+ default=25., required=False)
27
+
28
+ parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
29
+ help='Padding (top, bottom, left, right). Please adjust to include chin at least')
30
+
31
+ parser.add_argument('--face_det_batch_size', type=int,
32
+ help='Batch size for face detection', default=16)
33
+ parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)
34
+
35
+ parser.add_argument('--resize_factor', default=1, type=int,
36
+ help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
37
+
38
+ parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
39
+ help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
40
+ 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')
41
+
42
+ parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
43
+ help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
44
+ 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
45
+
46
+ parser.add_argument('--rotate', default=False, action='store_true',
47
+ help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.'
48
+ 'Use if you get a flipped result, despite feeding a normal looking video')
49
+
50
+ parser.add_argument('--nosmooth', default=False, action='store_true',
51
+ help='Prevent smoothing face detections over a short temporal window')
52
+
53
+ args = parser.parse_args()
54
+ args.img_size = 96
55
+
56
+ if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
57
+ args.static = True
58
+
59
+ def get_smoothened_boxes(boxes, T):
60
+ for i in range(len(boxes)):
61
+ if i + T > len(boxes):
62
+ window = boxes[len(boxes) - T:]
63
+ else:
64
+ window = boxes[i : i + T]
65
+ boxes[i] = np.mean(window, axis=0)
66
+ return boxes
67
+
68
+ def face_detect(images):
69
+ detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
70
+ flip_input=False, device=device)
71
+
72
+ batch_size = args.face_det_batch_size
73
+
74
+ while 1:
75
+ predictions = []
76
+ try:
77
+ for i in tqdm(range(0, len(images), batch_size)):
78
+ predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
79
+ except RuntimeError:
80
+ if batch_size == 1:
81
+ raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
82
+ batch_size //= 2
83
+ print('Recovering from OOM error; New batch size: {}'.format(batch_size))
84
+ continue
85
+ break
86
+
87
+ results = []
88
+ pady1, pady2, padx1, padx2 = args.pads
89
+ for rect, image in zip(predictions, images):
90
+ if rect is None:
91
+ cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
92
+ raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
93
+
94
+ y1 = max(0, rect[1] - pady1)
95
+ y2 = min(image.shape[0], rect[3] + pady2)
96
+ x1 = max(0, rect[0] - padx1)
97
+ x2 = min(image.shape[1], rect[2] + padx2)
98
+
99
+ results.append([x1, y1, x2, y2])
100
+
101
+ boxes = np.array(results)
102
+ if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
103
+ results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
104
+
105
+ del detector
106
+ return results
107
+
108
+ def datagen(frames, mels):
109
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
110
+
111
+ if args.box[0] == -1:
112
+ if not args.static:
113
+ face_det_results = face_detect(frames) # BGR2RGB for CNN face detection
114
+ else:
115
+ face_det_results = face_detect([frames[0]])
116
+ else:
117
+ print('Using the specified bounding box instead of face detection...')
118
+ y1, y2, x1, x2 = args.box
119
+ face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
120
+
121
+ for i, m in enumerate(mels):
122
+ idx = 0 if args.static else i%len(frames)
123
+ frame_to_save = frames[idx].copy()
124
+ face, coords = face_det_results[idx].copy()
125
+
126
+ face = cv2.resize(face, (args.img_size, args.img_size))
127
+
128
+ img_batch.append(face)
129
+ mel_batch.append(m)
130
+ frame_batch.append(frame_to_save)
131
+ coords_batch.append(coords)
132
+
133
+ if len(img_batch) >= args.wav2lip_batch_size:
134
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
135
+
136
+ img_masked = img_batch.copy()
137
+ img_masked[:, args.img_size//2:] = 0
138
+
139
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
140
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
141
+
142
+ yield img_batch, mel_batch, frame_batch, coords_batch
143
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
144
+
145
+ if len(img_batch) > 0:
146
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
147
+
148
+ img_masked = img_batch.copy()
149
+ img_masked[:, args.img_size//2:] = 0
150
+
151
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
152
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
153
+
154
+ yield img_batch, mel_batch, frame_batch, coords_batch
155
+
156
+ mel_step_size = 16
157
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
158
+ print('Using {} for inference.'.format(device))
159
+
160
+ def _load(checkpoint_path):
161
+ if device == 'cuda':
162
+ checkpoint = torch.load(checkpoint_path)
163
+ else:
164
+ checkpoint = torch.load(checkpoint_path,
165
+ map_location=lambda storage, loc: storage)
166
+ return checkpoint
167
+
168
+ def load_model(path):
169
+ model = Wav2Lip()
170
+ print("Load checkpoint from: {}".format(path))
171
+ checkpoint = _load(path)
172
+ s = checkpoint["state_dict"]
173
+ new_s = {}
174
+ for k, v in s.items():
175
+ new_s[k.replace('module.', '')] = v
176
+ model.load_state_dict(new_s)
177
+
178
+ model = model.to(device)
179
+ return model.eval()
180
+
181
+ def main():
182
+ if not os.path.isfile(args.face):
183
+ raise ValueError('--face argument must be a valid path to video/image file')
184
+
185
+ elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
186
+ full_frames = [cv2.imread(args.face)]
187
+ fps = args.fps
188
+
189
+ else:
190
+ video_stream = cv2.VideoCapture(args.face)
191
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
192
+
193
+ print('Reading video frames...')
194
+
195
+ full_frames = []
196
+ while 1:
197
+ still_reading, frame = video_stream.read()
198
+ if not still_reading:
199
+ video_stream.release()
200
+ break
201
+ if args.resize_factor > 1:
202
+ frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))
203
+
204
+ if args.rotate:
205
+ frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)
206
+
207
+ y1, y2, x1, x2 = args.crop
208
+ if x2 == -1: x2 = frame.shape[1]
209
+ if y2 == -1: y2 = frame.shape[0]
210
+
211
+ frame = frame[y1:y2, x1:x2]
212
+
213
+ full_frames.append(frame)
214
+
215
+ print ("Number of frames available for inference: "+str(len(full_frames)))
216
+
217
+ if not args.audio.endswith('.wav'):
218
+ print('Extracting raw audio...')
219
+ command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')
220
+
221
+ subprocess.call(command, shell=True)
222
+ args.audio = 'temp/temp.wav'
223
+
224
+ wav = audio.load_wav(args.audio, 16000)
225
+ mel = audio.melspectrogram(wav)
226
+ print(mel.shape)
227
+
228
+ if np.isnan(mel.reshape(-1)).sum() > 0:
229
+ raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
230
+
231
+ mel_chunks = []
232
+ mel_idx_multiplier = 80./fps
233
+ i = 0
234
+ while 1:
235
+ start_idx = int(i * mel_idx_multiplier)
236
+ if start_idx + mel_step_size > len(mel[0]):
237
+ mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
238
+ break
239
+ mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
240
+ i += 1
241
+
242
+ print("Length of mel chunks: {}".format(len(mel_chunks)))
243
+
244
+ full_frames = full_frames[:len(mel_chunks)]
245
+
246
+ batch_size = args.wav2lip_batch_size
247
+ gen = datagen(full_frames.copy(), mel_chunks)
248
+
249
+ for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
250
+ total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
251
+ if i == 0:
252
+ model = load_model(args.checkpoint_path)
253
+ print ("Model loaded")
254
+
255
+ frame_h, frame_w = full_frames[0].shape[:-1]
256
+ out = cv2.VideoWriter('temp/result.avi',
257
+ cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
258
+
259
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
260
+ mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
261
+
262
+ with torch.no_grad():
263
+ pred = model(mel_batch, img_batch)
264
+
265
+ pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
266
+
267
+ for p, f, c in zip(pred, frames, coords):
268
+ y1, y2, x1, x2 = c
269
+ p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
270
+
271
+ f[y1:y2, x1:x2] = p
272
+ out.write(f)
273
+
274
+ out.release()
275
+
276
+ command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
277
+ subprocess.call(command, shell=platform.system() != 'Windows')
278
+
279
+ if __name__ == '__main__':
280
+ main()
Wav2Lip/results/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Generated results will be placed in this folder by default.