lorneluo committed on
Commit abe5389
1 Parent(s): 7dd4b4a

update requirements

Files changed (2)
  1. requirements.txt +5 -2
  2. wav2lip/inference.py +211 -201
requirements.txt CHANGED
@@ -15,6 +15,9 @@ scipy
 tb-nightly
 yapf
 realesrgan
-
 ffmpeg
-gradio
+gradio==4.1.2
+ffmpy
+flask_ngrok2
+flask_ngrok
+opencv-python
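
For context, a minimal sketch of installing the updated dependency list from Python; it assumes requirements.txt sits in the current working directory and simply shells out to pip.

    # Hedged sketch: install the pinned packages listed above with this
    # interpreter's own pip (equivalent to "pip install -r requirements.txt").
    import subprocess
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])
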
wav2lip/inference.py CHANGED
@@ -10,271 +10,281 @@ import platform
 
 parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
 
-parser.add_argument('--checkpoint_path', type=str,
-                    help='Name of saved checkpoint to load weights from', required=True)

-parser.add_argument('--face', type=str,
-                    help='Filepath of video/image that contains faces to use', required=True)
-parser.add_argument('--audio', type=str,
-                    help='Filepath of video/audio file to use as raw audio source', required=True)
-parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
-                    default='results/result_voice.mp4')

-parser.add_argument('--static', type=bool,
-                    help='If True, then use only first video frame for inference', default=False)
-parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
-                    default=25., required=False)

-parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
-                    help='Padding (top, bottom, left, right). Please adjust to include chin at least')

-parser.add_argument('--face_det_batch_size', type=int,
-                    help='Batch size for face detection', default=16)
 parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)

-parser.add_argument('--resize_factor', default=1, type=int,
-                    help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')

-parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
-                    help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
-                         'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')

-parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
-                    help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
-                         'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')

 parser.add_argument('--rotate', default=False, action='store_true',
-                    help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.'
-                         'Use if you get a flipped result, despite feeding a normal looking video')

 parser.add_argument('--nosmooth', default=False, action='store_true',
-                    help='Prevent smoothing face detections over a short temporal window')

 args = parser.parse_args()
 args.img_size = 96

 if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
-    args.static = True

 def get_smoothened_boxes(boxes, T):
-    for i in range(len(boxes)):
-        if i + T > len(boxes):
-            window = boxes[len(boxes) - T:]
-        else:
-            window = boxes[i : i + T]
-        boxes[i] = np.mean(window, axis=0)
-    return boxes

 def face_detect(images):
-    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
-                                            flip_input=False, device=device)

-    batch_size = args.face_det_batch_size

-    while 1:
-        predictions = []
-        try:
-            for i in tqdm(range(0, len(images), batch_size)):
-                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
-        except RuntimeError:
-            if batch_size == 1:
-                raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
-            batch_size //= 2
-            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
-            continue
-        break

-    results = []
-    pady1, pady2, padx1, padx2 = args.pads
-    for rect, image in zip(predictions, images):
-        if rect is None:
-            cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
-            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

-        y1 = max(0, rect[1] - pady1)
-        y2 = min(image.shape[0], rect[3] + pady2)
-        x1 = max(0, rect[0] - padx1)
-        x2 = min(image.shape[1], rect[2] + padx2)

-        results.append([x1, y1, x2, y2])

-    boxes = np.array(results)
-    if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
-    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

-    del detector
-    return results

 def datagen(frames, mels):
-    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

-    if args.box[0] == -1:
-        if not args.static:
-            face_det_results = face_detect(frames) # BGR2RGB for CNN face detection
-        else:
-            face_det_results = face_detect([frames[0]])
-    else:
-        print('Using the specified bounding box instead of face detection...')
-        y1, y2, x1, x2 = args.box
-        face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]

-    for i, m in enumerate(mels):
-        idx = 0 if args.static else i%len(frames)
-        frame_to_save = frames[idx].copy()
-        face, coords = face_det_results[idx].copy()

-        face = cv2.resize(face, (args.img_size, args.img_size))

-        img_batch.append(face)
-        mel_batch.append(m)
-        frame_batch.append(frame_to_save)
-        coords_batch.append(coords)

-        if len(img_batch) >= args.wav2lip_batch_size:
-            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

-            img_masked = img_batch.copy()
-            img_masked[:, args.img_size//2:] = 0

-            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
-            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

-            yield img_batch, mel_batch, frame_batch, coords_batch
-            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

-    if len(img_batch) > 0:
-        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

-        img_masked = img_batch.copy()
-        img_masked[:, args.img_size//2:] = 0

-        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
-        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

-        yield img_batch, mel_batch, frame_batch, coords_batch

 mel_step_size = 16
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 print('Using {} for inference.'.format(device))

 def _load(checkpoint_path):
-    if device == 'cuda':
-        checkpoint = torch.load(checkpoint_path)
-    else:
-        checkpoint = torch.load(checkpoint_path,
-                                map_location=lambda storage, loc: storage)
-    return checkpoint

 def load_model(path):
-    model = Wav2Lip()
-    print("Load checkpoint from: {}".format(path))
-    checkpoint = _load(path)
-    s = checkpoint["state_dict"]
-    new_s = {}
-    for k, v in s.items():
-        new_s[k.replace('module.', '')] = v
-    model.load_state_dict(new_s)

-    model = model.to(device)
-    return model.eval()

 def main():
-    if not os.path.isfile(args.face):
-        raise ValueError('--face argument must be a valid path to video/image file')

-    elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
-        full_frames = [cv2.imread(args.face)]
-        fps = args.fps

-    else:
-        video_stream = cv2.VideoCapture(args.face)
-        fps = video_stream.get(cv2.CAP_PROP_FPS)

-        print('Reading video frames...')

-        full_frames = []
-        while 1:
-            still_reading, frame = video_stream.read()
-            if not still_reading:
-                video_stream.release()
-                break
-            if args.resize_factor > 1:
-                frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))

-            if args.rotate:
-                frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)

-            y1, y2, x1, x2 = args.crop
-            if x2 == -1: x2 = frame.shape[1]
-            if y2 == -1: y2 = frame.shape[0]

-            frame = frame[y1:y2, x1:x2]

-            full_frames.append(frame)

-    print ("Number of frames available for inference: "+str(len(full_frames)))

-    if not args.audio.endswith('.wav'):
-        print('Extracting raw audio...')
-        command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')

-        subprocess.call(command, shell=True)
-        args.audio = 'temp/temp.wav'

-    wav = audio.load_wav(args.audio, 16000)
-    mel = audio.melspectrogram(wav)
-    print(mel.shape)

-    if np.isnan(mel.reshape(-1)).sum() > 0:
-        raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

-    mel_chunks = []
-    mel_idx_multiplier = 80./fps
-    i = 0
-    while 1:
-        start_idx = int(i * mel_idx_multiplier)
-        if start_idx + mel_step_size > len(mel[0]):
-            mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
-            break
-        mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
-        i += 1

-    print("Length of mel chunks: {}".format(len(mel_chunks)))

-    full_frames = full_frames[:len(mel_chunks)]

-    batch_size = args.wav2lip_batch_size
-    gen = datagen(full_frames.copy(), mel_chunks)

-    for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
-                                                                    total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
-        if i == 0:
-            model = load_model(args.checkpoint_path)
-            print ("Model loaded")

-            frame_h, frame_w = full_frames[0].shape[:-1]
-            out = cv2.VideoWriter('/tmp/result.avi',
-                                  cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))

-        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
-        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

-        with torch.no_grad():
-            pred = model(mel_batch, img_batch)

-        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

-        for p, f, c in zip(pred, frames, coords):
-            y1, y2, x1, x2 = c
-            p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))

-            f[y1:y2, x1:x2] = p
-            out.write(f)

-    out.release()

-    command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, '/tmp/result.avi', args.outfile)
-    subprocess.call(command, shell=platform.system() != 'Windows')

 if __name__ == '__main__':
-    main()
 
 parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
 
+parser.add_argument('--checkpoint_path', type=str,
+                    help='Name of saved checkpoint to load weights from', required=True)

+parser.add_argument('--face', type=str,
+                    help='Filepath of video/image that contains faces to use', required=True)
+parser.add_argument('--audio', type=str,
+                    help='Filepath of video/audio file to use as raw audio source', required=True)
+parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
+                    default='results/result_voice.mp4')

+parser.add_argument('--static', type=bool,
+                    help='If True, then use only first video frame for inference', default=False)
+parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
+                    default=25., required=False)

+parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
+                    help='Padding (top, bottom, left, right). Please adjust to include chin at least')

+parser.add_argument('--face_det_batch_size', type=int,
+                    help='Batch size for face detection', default=16)
 parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)

+parser.add_argument('--resize_factor', default=1, type=int,
+                    help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')

+parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
+                    help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
+                         'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')

+parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
+                    help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
+                         'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')

 parser.add_argument('--rotate', default=False, action='store_true',
+                    help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.'
+                         'Use if you get a flipped result, despite feeding a normal looking video')

 parser.add_argument('--nosmooth', default=False, action='store_true',
+                    help='Prevent smoothing face detections over a short temporal window')

 args = parser.parse_args()
 args.img_size = 96

 if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
+    args.static = True


 def get_smoothened_boxes(boxes, T):
+    for i in range(len(boxes)):
+        if i + T > len(boxes):
+            window = boxes[len(boxes) - T:]
+        else:
+            window = boxes[i: i + T]
+        boxes[i] = np.mean(window, axis=0)
+    return boxes


 def face_detect(images):
+    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
+                                            flip_input=False, device=device)

+    batch_size = args.face_det_batch_size

+    while 1:
+        predictions = []
+        try:
+            for i in tqdm(range(0, len(images), batch_size)):
+                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
+        except RuntimeError:
+            if batch_size == 1:
+                raise RuntimeError(
+                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
+            batch_size //= 2
+            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
+            continue
+        break

+    results = []
+    pady1, pady2, padx1, padx2 = args.pads
+    for rect, image in zip(predictions, images):
+        if rect is None:
+            cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
+            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

+        y1 = max(0, rect[1] - pady1)
+        y2 = min(image.shape[0], rect[3] + pady2)
+        x1 = max(0, rect[0] - padx1)
+        x2 = min(image.shape[1], rect[2] + padx2)

+        results.append([x1, y1, x2, y2])

+    boxes = np.array(results)
+    if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
+    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

+    del detector
+    return results


 def datagen(frames, mels):
+    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

+    if args.box[0] == -1:
+        if not args.static:
+            face_det_results = face_detect(frames) # BGR2RGB for CNN face detection
+        else:
+            face_det_results = face_detect([frames[0]])
+    else:
+        print('Using the specified bounding box instead of face detection...')
+        y1, y2, x1, x2 = args.box
+        face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]

+    for i, m in enumerate(mels):
+        idx = 0 if args.static else i % len(frames)
+        frame_to_save = frames[idx].copy()
+        face, coords = face_det_results[idx].copy()

+        face = cv2.resize(face, (args.img_size, args.img_size))

+        img_batch.append(face)
+        mel_batch.append(m)
+        frame_batch.append(frame_to_save)
+        coords_batch.append(coords)

+        if len(img_batch) >= args.wav2lip_batch_size:
+            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

+            img_masked = img_batch.copy()
+            img_masked[:, args.img_size // 2:] = 0

+            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

+            yield img_batch, mel_batch, frame_batch, coords_batch
+            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

+    if len(img_batch) > 0:
+        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

+        img_masked = img_batch.copy()
+        img_masked[:, args.img_size // 2:] = 0

+        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

+        yield img_batch, mel_batch, frame_batch, coords_batch


 mel_step_size = 16
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 print('Using {} for inference.'.format(device))


 def _load(checkpoint_path):
+    if device == 'cuda':
+        checkpoint = torch.load(checkpoint_path)
+    else:
+        checkpoint = torch.load(checkpoint_path,
+                                map_location=lambda storage, loc: storage)
+    return checkpoint


 def load_model(path):
+    model = Wav2Lip()
+    print("Load checkpoint from: {}".format(path))
+    checkpoint = _load(path)
+    s = checkpoint["state_dict"]
+    new_s = {}
+    for k, v in s.items():
+        new_s[k.replace('module.', '')] = v
+    model.load_state_dict(new_s)

+    model = model.to(device)
+    return model.eval()


 def main():
+    if not os.path.isfile(args.face):
+        raise ValueError('--face argument must be a valid path to video/image file')

+    elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
+        full_frames = [cv2.imread(args.face)]
+        fps = args.fps

+    else:
+        video_stream = cv2.VideoCapture(args.face)
+        fps = video_stream.get(cv2.CAP_PROP_FPS)

+        print('Reading video frames...')

+        full_frames = []
+        while 1:
+            still_reading, frame = video_stream.read()
+            if not still_reading:
+                video_stream.release()
+                break
+            if args.resize_factor > 1:
+                frame = cv2.resize(frame, (frame.shape[1] // args.resize_factor, frame.shape[0] // args.resize_factor))

+            if args.rotate:
+                frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)

+            y1, y2, x1, x2 = args.crop
+            if x2 == -1: x2 = frame.shape[1]
+            if y2 == -1: y2 = frame.shape[0]

+            frame = frame[y1:y2, x1:x2]

+            full_frames.append(frame)

+    print("Number of frames available for inference: " + str(len(full_frames)))

+    if not args.audio.endswith('.wav'):
+        print('Extracting raw audio...')
+        command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')

+        subprocess.call(command, shell=True)
+        args.audio = 'temp/temp.wav'

+    wav = audio.load_wav(args.audio, 16000)
+    mel = audio.melspectrogram(wav)
+    print(mel.shape)

+    if np.isnan(mel.reshape(-1)).sum() > 0:
+        raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

+    mel_chunks = []
+    mel_idx_multiplier = 80. / fps
+    i = 0
+    while 1:
+        start_idx = int(i * mel_idx_multiplier)
+        if start_idx + mel_step_size > len(mel[0]):
+            mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
+            break
+        mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
+        i += 1

+    print("Length of mel chunks: {}".format(len(mel_chunks)))

+    full_frames = full_frames[:len(mel_chunks)]

+    batch_size = args.wav2lip_batch_size
+    gen = datagen(full_frames.copy(), mel_chunks)

+    for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
+                                                                    total=int(
+                                                                        np.ceil(float(len(mel_chunks)) / batch_size)))):
+        if i == 0:
+            model = load_model(args.checkpoint_path)
+            print("Model loaded")

+            frame_h, frame_w = full_frames[0].shape[:-1]
+            out = cv2.VideoWriter('/tmp/result.avi',
+                                  cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))

+        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
+        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

+        with torch.no_grad():
+            pred = model(mel_batch, img_batch)

+        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

+        for p, f, c in zip(pred, frames, coords):
+            y1, y2, x1, x2 = c
+            p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))

+            f[y1:y2, x1:x2] = p
+            out.write(f)

+    out.release()

+    command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, '/tmp/result.avi', args.outfile)
+    subprocess.call(command, shell=platform.system() != 'Windows')


 if __name__ == '__main__':
+    main()
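
For reference, a hedged sketch of driving the updated script from Python with the flags its parser defines above; the checkpoint, face, and audio paths are placeholders, not files shipped with this commit.

    # Hypothetical invocation of wav2lip/inference.py; all input paths are placeholders.
    import subprocess
    import sys

    subprocess.check_call([
        sys.executable, "wav2lip/inference.py",
        "--checkpoint_path", "checkpoints/wav2lip_gan.pth",  # placeholder checkpoint file
        "--face", "sample_data/face.mp4",                    # placeholder face video/image
        "--audio", "sample_data/speech.wav",                 # placeholder audio source
        "--outfile", "results/result_voice.mp4",             # matches the script's default
    ])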