hytian2@gmail.com committed
Commit: eae1cca
1 Parent(s): 48798aa
app.py CHANGED
@@ -35,10 +35,9 @@ available_audios = natsorted(glob.glob("./assets/audios/*.wav"))
 available_audios = [os.path.basename(x) for x in available_audios]
 
 
-
 with gr.Blocks() as demo:
     gr.HTML(
-        """
+        """
         <h1 style="text-align: center; font-size: 40px; font-family: 'Times New Roman', Times, serif;">
         Free-View Expressive Talking Head Video Editing
         </h1>
@@ -51,7 +50,8 @@ with gr.Blocks() as demo:
         <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg#center" alt="Duplicate Space">
         </a>
         </p>
-        """)
+        """
+    )
     with gr.Column(elem_id="col-container"):
         with gr.Row():
             with gr.Column():
@@ -80,4 +80,4 @@ with gr.Blocks() as demo:
     audio_input.select(lambda x: "./assets/audios/" + x, audio_input, audio_preview_output)
     submit_btn.click(process, inputs, outputs)
 
-    demo.queue(max_size=12).launch()
+    demo.queue(max_size=10).launch()
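Note: the one behavioural change in app.py is the request-queue cap, which drops from 12 to 10; the rest is formatting of the HTML header string. For reference, a minimal skeleton of the Gradio pattern involved (a generic sketch, not the Space's full app.py; only the title text and the max_size value come from the diff):

import gradio as gr

with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1 style="text-align: center;">Free-View Expressive Talking Head Video Editing</h1>
        """
    )
    # ... components and event wiring elided ...

demo.queue(max_size=10).launch()  # hold at most 10 pending requests in the queue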
assets/coords/sample1.npz CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5b200f395b09505d61f3efb67feaacbbd5bb358e75b476c4da083e4a7cef58af
-size 525
+oid sha256:05569fbc982413520a7f81636cba156e6f67344c1cd13f4831ccd95cbb1bf0ad
+size 454
assets/coords/sample2.npz CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3ac70dd3972f406d9e8195283d11395a7b1e2528bdbdec4a3420eeac919489c9
-size 909
+oid sha256:37790e74ae602e20aa3be2811f60cda21905cc1a88d6efc53f99ec9a73f7e1df
+size 810
assets/coords/sample3.npz CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:246e4910d5ae9937f2d692beb6d6267dcb2f09bf7b7e0bd75d373a167289cf08
-size 598
+oid sha256:20b7d61ba04d8743c5d0a26b235897df9664b5dc095d96a17be3b9cdbfb06142
+size 528
assets/coords/sample4.npz CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:497b14d4185a447327fac69602b66997dc791ff333ead12680c36e3e27d20195
-size 656
+oid sha256:79b016ed5c7934c7e667dc6999457726175b17d4ef0fb67a755c5d038f0b75ec
+size 567
assets/coords/sample5.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5db52100e8db15ec88b4b3e4d0229fcbdcf96f623b6f4eb8890f7915b430b914
+size 655
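Note: the .npz entries above are Git LFS pointer files holding precomputed per-frame face-crop coordinates; they were presumably regenerated after the preprocessing changes further down, and sample5.npz pairs with the sample5.mp4 clip added next. A quick sketch of how these files are read, mirroring load_from_npz in preprocess_videos.py:

import numpy as np

# The "coords" key holds the per-frame face-crop boxes produced by preprocess()
# for the matching video under ./assets/videos/.
coords = np.load("./assets/coords/sample5.npz")["coords"]
print(coords.shape)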
assets/videos/sample5.mp4 ADDED
Binary file (439 kB)
 
attributtes_utils.py CHANGED
@@ -5,43 +5,45 @@ import sys
 def input_pose(pose_select="front"):
     step = 1
     if pose_select == "front":
-        pose = [[0.0, 0.0, 0.0] for i in range(0, 10, step)]#-20 to 20
+        pose = [[0.0, 0.0, 0.0] for i in range(0, 10, step)]  # -20 to 20
     elif pose_select == "left_right_shaking":
-        pose = [[-i, 0.0, 0.0] for i in range(0, 20, step)]#0 to -20
+        pose = [[-i, 0.0, 0.0] for i in range(0, 20, step)]  # 0 to -20
         pose += [[i - 20.0, 0.0, 0.0] for i in range(0, 40, step)]  # -20 to 20
         pose += [[20.0 - i, 0.0, 0.0] for i in range(0, 20, step)]  # 20 to 0
         pose = pose + pose
         pose = pose + pose
         pose = pose + pose
-        # pose = pose + pose[::-1]
     else:
         raise ValueError("pose_select Error")
 
     return pose
 
 
-EMOTIONS = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']
+EMOTIONS = ["angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"]
+
+
 def input_emotion(emotion_select="neutral"):
     sacle_factor = 2
     if emotion_select == "neutral":
-        emotion = [[0.0,0.0,0.0,0.0,0.0,0.0,1.0] for _ in range(2)]#((i%50))*0.04
+        emotion = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] for _ in range(2)]  # ((i%50))*0.04
     elif emotion_select == "happy":
-        emotion = [[0.0,0.0,0.0,1.0,0.0,0.0,0.0] for _ in range(2)]#((i%50))*0.04
+        emotion = [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0] for _ in range(2)]  # ((i%50))*0.04
     elif emotion_select == "angry":
-        emotion = [[1.0,0.0,0.0,0.0,0.0,0.0,0.0] for _ in range(2)]
+        emotion = [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] for _ in range(2)]
     elif emotion_select == "surprised":
-        emotion = [[0.0,0.0,0.0,0.0,0.0,1.0,0.0] for _ in range(2)]
+        emotion = [[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0] for _ in range(2)]
     else:
         raise ValueError("emotion_select Error")
 
-    return emotion * sacle_factor
+    return emotion * sacle_factor
 
 
 def input_blink(blink_select="yes"):
     if blink_select == "yes":
-        blink = [[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [0.8], [0.6], [0.0], [0.0], [1.0]]
+        blink = [[1.0] for _ in range(25)]
+        blink += [[0.8], [0.6], [0.0], [0.0]]
+        blink += [[1.0] for _ in range(5)]
         blink = blink + blink + blink
     else:
         blink = [[1.0] for _ in range(2)]
     return blink
-
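The substantive change in this file is the blink pattern: the hard-coded 12-entry list becomes an explicit open / closing / reopening sequence (1.0 appears to mean fully open, 0.0 closed). A quick check of the sequence lengths that fall out of the hunk above; the numbers are plain arithmetic on the code shown, nothing else is assumed:

blink = [[1.0] for _ in range(25)]       # eyes held open
blink += [[0.8], [0.6], [0.0], [0.0]]    # closing phase
blink += [[1.0] for _ in range(5)]       # back to open
blink = blink + blink + blink            # three blink cycles
print(len(blink))                        # 102 entries (was 3 * 12 = 36 before this commit)

pose = [[-i, 0.0, 0.0] for i in range(0, 20, 1)]          # 0 to -20
pose += [[i - 20.0, 0.0, 0.0] for i in range(0, 40, 1)]   # -20 to 20
pose += [[20.0 - i, 0.0, 0.0] for i in range(0, 20, 1)]   # 20 to 0
pose = pose + pose
pose = pose + pose
pose = pose + pose
print(len(pose))                         # 640 entries of left/right head shaking (unchanged)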
 
fete_model.py CHANGED
@@ -206,11 +206,11 @@ class FETE_model(nn.Module):
         # face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
         # print(audio_sequences.size(), face_sequences.size(), pose_sequences.size(), emotion_sequences.size())
 
-        audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1
-        pose_embedding = self.pose_encoder(pose_sequences)  # B, 512, 1, 1
+        audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1
+        pose_embedding = self.pose_encoder(pose_sequences)  # B, 512, 1, 1
         emotion_embedding = self.emotion_encoder(emotion_sequences)  # B, 512, 1, 1
-        blink_embedding = self.blink_encoder(blink_sequences)  # B, 512, 1, 1
-        inputs_embedding = torch.cat((audio_embedding, pose_embedding, emotion_embedding, blink_embedding), dim=1)  # B, 1536, 1, 1
+        blink_embedding = self.blink_encoder(blink_sequences)  # B, 512, 1, 1
+        inputs_embedding = torch.cat((audio_embedding, pose_embedding, emotion_embedding, blink_embedding), dim=1)  # B, 1536, 1, 1
         # print(audio_embedding.size(), pose_embedding.size(), emotion_embedding.size(), inputs_embedding.size())
 
         feats = []
@@ -261,10 +261,10 @@ class Self_Attention(nn.Module):
         """
         super(Self_Attention, self).__init__()
         self.query_conv = nn.Conv2d(in_channels=in_planes_s, out_channels=in_planes_s // 8, kernel_size=1)
-        self.key_conv = nn.Conv2d(in_channels=in_planes_r, out_channels=in_planes_r // 8, kernel_size=1)
+        self.key_conv = nn.Conv2d(in_channels=in_planes_r, out_channels=in_planes_r // 8, kernel_size=1)
         self.value_conv = nn.Conv2d(in_channels=in_planes_r, out_channels=in_planes_r, kernel_size=1)
-        self.gamma = nn.Parameter(torch.zeros(1))
-        self.softmax = nn.Softmax(dim=-1)
+        self.gamma = nn.Parameter(torch.zeros(1))
+        self.softmax = nn.Softmax(dim=-1)
 
     def forward(self, source):
         source = source.float() if isinstance(source, torch.cuda.HalfTensor) else source
@@ -286,11 +286,11 @@ class Self_Attention(nn.Module):
         r_batchsize, rC, rH, rW = reference.size()
 
         proj_query = self.query_conv(source).view(s_batchsize, -1, sH * sW).permute(0, 2, 1)
-        proj_key = self.key_conv(reference).view(r_batchsize, -1, rW * rH)
-        energy = torch.bmm(proj_query, proj_key)
-        attention = self.softmax(energy)
+        proj_key = self.key_conv(reference).view(r_batchsize, -1, rW * rH)
+        energy = torch.bmm(proj_query, proj_key)
+        attention = self.softmax(energy)
         proj_value = self.value_conv(reference).view(r_batchsize, -1, rH * rW)
-        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
-        out = out.view(s_batchsize, sC, sH, sW)
-        out = self.gamma * out + source
+        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
+        out = out.view(s_batchsize, sC, sH, sW)
+        out = self.gamma * out + source
         return out.half() if isinstance(source, torch.cuda.FloatTensor) else out
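Both hunks in this file are whitespace/alignment-only; the computation is unchanged. For readers scanning the diff, a self-contained sketch of the Self_Attention shape flow using the same operations as above (the sizes are arbitrary, and it assumes source and reference share the channel count C, which the final view and the residual add require):

import torch
import torch.nn as nn

B, C, H, W = 2, 64, 8, 8
source = torch.randn(B, C, H, W)
reference = torch.randn(B, C, H, W)

query_conv = nn.Conv2d(C, C // 8, kernel_size=1)
key_conv = nn.Conv2d(C, C // 8, kernel_size=1)
value_conv = nn.Conv2d(C, C, kernel_size=1)
gamma = torch.zeros(1)                 # learnable scale, initialised to 0 in the module
softmax = nn.Softmax(dim=-1)

proj_query = query_conv(source).view(B, -1, H * W).permute(0, 2, 1)   # B, HW, C//8
proj_key = key_conv(reference).view(B, -1, H * W)                     # B, C//8, HW
attention = softmax(torch.bmm(proj_query, proj_key))                  # B, HW, HW
proj_value = value_conv(reference).view(B, -1, H * W)                 # B, C, HW
out = torch.bmm(proj_value, attention.permute(0, 2, 1)).view(B, C, H, W)
out = gamma * out + source             # residual; acts as identity at initialisation
print(out.shape)                       # torch.Size([2, 64, 8, 8])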
inference_util.py CHANGED
@@ -280,7 +280,11 @@ def infenrece(model, face_path, audio_path, pose, emotion, blink, preview=False)
     else:
         outfile = "/tmp/{}.mp4".format(timestamp)
         tmp_video = "/tmp/temp_{}.mp4".format(timestamp)
-    writer = imageio.get_writer(tmp_video, fps=fps, codec="libx264", quality=10, pixelformat="yuv420p", macro_block_size=1) if not preview else None
+    writer = (
+        imageio.get_writer(tmp_video, fps=fps, codec="libx264", quality=10, pixelformat="yuv420p", macro_block_size=1)
+        if not preview
+        else None
+    )
     # print('Generating frames...', outfile, steps)
     for inputs, frames, coords in tqdm(gen, total=steps):
         with torch.no_grad():
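The writer construction is only reformatted onto multiple lines here; behaviour is identical. For context, a minimal sketch of how an imageio writer configured this way is driven (output path, fps, frame count and frame contents below are placeholders; writing .mp4 assumes imageio's ffmpeg backend is installed):

import imageio
import numpy as np

writer = imageio.get_writer("/tmp/demo.mp4", fps=25, codec="libx264", quality=10,
                            pixelformat="yuv420p", macro_block_size=1)
for _ in range(25):
    writer.append_data(np.zeros((256, 256, 3), dtype=np.uint8))  # one black RGB frame
writer.close()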
preprocess_videos.py CHANGED
@@ -9,6 +9,7 @@ from natsort import natsorted
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+
 def get_squre_coords(coords, image, size=None, last_size=None):
     y1, y2, x1, x2 = coords
     w, h = x2 - x1, y2 - y1
@@ -63,7 +64,7 @@ def face_detect(images, pads):
         x1 = max(0, rect[0] - padx1)
         x2 = min(image.shape[1], rect[2] + padx2)
         # y_gap, x_gap = ((y2 - y1) * 2) // 3, ((x2 - x1) * 2) // 3
-        y_gap, x_gap = (y2 - y1)//2, (x2 - x1)//2
+        y_gap, x_gap = (y2 - y1) // 2, (x2 - x1) // 2
         coords_ = [y1 - y_gap, y2 + y_gap, x1 - x_gap, x2 + x_gap]
 
         _, coords = get_squre_coords(coords_, image)
@@ -79,18 +80,20 @@ def face_detect(images, pads):
     print("Number of frames cropped: {}".format(len(results)))
     print("First coords: {}".format(results[0]))
     boxes = np.array(results)
-    boxes = get_smoothened_boxes(boxes, T=15)
+    boxes = get_smoothened_boxes(boxes, T=25)
     # results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
 
     del detector
     return boxes
 
+
 def add_black(imgs):
     for i in range(len(imgs)):
         imgs[i] = cv2.vconcat([np.zeros((100, imgs[i].shape[1], 3), dtype=np.uint8), imgs[i], np.zeros((20, imgs[i].shape[1], 3), dtype=np.uint8)])
 
     return imgs
 
+
 def preprocess(video_dir="./assets/videos", save_dir="./assets/coords"):
     all_videos = natsorted(glob.glob(os.path.join(video_dir, "*.mp4")))
     for video_path in all_videos:
@@ -115,5 +118,6 @@ def load_from_npz(video_name, save_dir="./assets/coords"):
     npz = np.load(os.path.join(save_dir, video_name + ".npz"))
     return npz["coords"]
 
+
 if __name__ == "__main__":
-    preprocess()
+    preprocess()
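Apart from blank-line formatting, the functional change in this file is the temporal smoothing window passed to get_smoothened_boxes, raised from 15 to 25 frames, i.e. stronger smoothing of the detected face boxes over time. get_smoothened_boxes itself is not part of this diff; the sketch below is an assumed moving-average implementation to illustrate what the window T controls, not the repository's code:

import numpy as np

def smooth_boxes(boxes, T=25):
    """Assumed illustration: replace each box by the mean of a window of T boxes."""
    boxes = np.asarray(boxes, dtype=np.float64)
    smoothed = boxes.copy()
    for i in range(len(boxes)):
        window = boxes[i : i + T] if i + T <= len(boxes) else boxes[len(boxes) - T :]
        smoothed[i] = np.mean(window, axis=0)
    return smoothed

boxes = np.random.rand(100, 4) * 256          # 100 frames of per-frame crop boxes (dummy data)
print(smooth_boxes(boxes, T=25).shape)        # (100, 4): same shape, temporally smoothed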