harlanhong commited on
Commit
e418082
·
1 Parent(s): c45e94d
Files changed (7) hide show
  1. .gitignore +1 -0
  2. app.py +13 -106
  3. demo_dagan.py +92 -82
  4. depth.pth +0 -3
  5. encoder.pth +0 -3
  6. generator.pt +0 -3
  7. kp_detector.pt +0 -3
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pyc
app.py CHANGED
@@ -3,21 +3,14 @@ import shutil
3
  import gradio as gr
4
  from PIL import Image
5
  import subprocess
 
6
  #os.chdir('Restormer')
7
- from demo_dagan import *
8
  # Download sample images
9
- import torch
10
- import torch.nn.functional as F
11
- import os
12
- from skimage import img_as_ubyte
13
- import imageio
14
- from skimage.transform import resize
15
- import numpy as np
16
- import modules.generator as G
17
- import modules.keypoint_detector as KPD
18
- import yaml
19
- from collections import OrderedDict
20
- import depth
21
 
22
  examples = [['project/cartoon2.jpg','project/video1.mp4'],
23
  ['project/cartoon3.jpg','project/video2.mp4'],
@@ -25,9 +18,6 @@ examples = [['project/cartoon2.jpg','project/video1.mp4'],
25
  ['project/celeb2.jpg','project/video2.mp4'],
26
  ]
27
 
28
-
29
- inference_on = ['Full Resolution Image', 'Downsampled Image']
30
-
31
  title = "DaGAN"
32
  description = """
33
  Gradio demo for <b>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</b>, CVPR 2022L. <a href='https://arxiv.org/abs/2203.06605'>[Paper]</a><a href='https://github.com/harlanhong/CVPR2022-DaGAN'>[Github Code]</a>\n
@@ -38,99 +28,16 @@ Gradio demo for <b>Depth-Aware Generative Adversarial Network for Talking Head V
38
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.06605'>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</a> | <a href='https://github.com/harlanhong/CVPR2022-DaGAN'>Github Repo</a></p>"
39
 
40
 
41
- def inference(source_image, video):
42
  if not os.path.exists('temp'):
43
- os.system('mkdir temp')
44
- cmd = f"ffmpeg -y -ss 00:00:00 -i {video} -to 00:00:08 -c copy video_input.mp4"
 
45
  subprocess.run(cmd.split())
46
  driving_video = "video_input.mp4"
47
- output = "rst.mp4"
48
- with open("config/vox-adv-256.yaml") as f:
49
- config = yaml.load(f)
50
- generator = G.SPADEDepthAwareGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params'])
51
- config['model_params']['common_params']['num_channels'] = 4
52
- kp_detector = KPD.KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params'])
53
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
54
-
55
-
56
- g_checkpoint = torch.load("generator.pt", map_location=device)
57
- kp_checkpoint = torch.load("kp_detector.pt", map_location=device)
58
-
59
- ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in g_checkpoint.items())
60
- generator.load_state_dict(ckp_generator)
61
- ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in kp_checkpoint.items())
62
- kp_detector.load_state_dict(ckp_kp_detector)
63
-
64
- depth_encoder = depth.ResnetEncoder(18, False)
65
- depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
66
- loaded_dict_enc = torch.load('encoder.pth')
67
- loaded_dict_dec = torch.load('depth.pth')
68
- filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
69
- depth_encoder.load_state_dict(filtered_dict_enc)
70
- ckp_depth_decoder= {k: v for k, v in loaded_dict_dec.items() if k in depth_decoder.state_dict()}
71
- depth_decoder.load_state_dict(ckp_depth_decoder)
72
- depth_encoder.eval()
73
- depth_decoder.eval()
74
-
75
- # device = torch.device('cpu')
76
- # stx()
77
-
78
- generator = generator.to(device)
79
- kp_detector = kp_detector.to(device)
80
- depth_encoder = depth_encoder.to(device)
81
- depth_decoder = depth_decoder.to(device)
82
-
83
- generator.eval()
84
- kp_detector.eval()
85
- depth_encoder.eval()
86
- depth_decoder.eval()
87
-
88
- img_multiple_of = 8
89
-
90
- with torch.inference_mode():
91
- if torch.cuda.is_available():
92
- torch.cuda.ipc_collect()
93
- torch.cuda.empty_cache()
94
- source_image = imageio.imread(source_image)
95
- reader = imageio.get_reader(driving_video)
96
- fps = reader.get_meta_data()['fps']
97
- driving_video = []
98
- try:
99
- for im in reader:
100
- driving_video.append(im)
101
- except RuntimeError:
102
- pass
103
- reader.close()
104
-
105
- source_image = resize(source_image, (256, 256))[..., :3]
106
- driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
107
-
108
-
109
-
110
- i = find_best_frame(source_image, driving_video)
111
- print ("Best frame: " + str(i))
112
- driving_forward = driving_video[i:]
113
- driving_backward = driving_video[:(i+1)][::-1]
114
- sources_forward, drivings_forward, predictions_forward,depth_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
115
- sources_backward, drivings_backward, predictions_backward,depth_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
116
- predictions = predictions_backward[::-1] + predictions_forward[1:]
117
- sources = sources_backward[::-1] + sources_forward[1:]
118
- drivings = drivings_backward[::-1] + drivings_forward[1:]
119
- depth_gray = depth_backward[::-1] + depth_forward[1:]
120
-
121
- imageio.mimsave(output, [np.concatenate((img_as_ubyte(s),img_as_ubyte(d),img_as_ubyte(p)),1) for (s,d,p) in zip(sources, drivings, predictions)], fps=fps)
122
- imageio.mimsave("gray.mp4", depth_gray, fps=fps)
123
- # merge the gray video
124
- animation = np.array(imageio.mimread(output,memtest=False))
125
- gray = np.array(imageio.mimread("gray.mp4",memtest=False))
126
-
127
- src_dst = animation[:,:,:512,:]
128
- animate = animation[:,:,512:,:]
129
- merge = np.concatenate((src_dst,gray,animate),2)
130
- imageio.mimsave(output, merge, fps=fps)
131
-
132
- return output
133
-
134
  gr.Interface(
135
  inference,
136
  [
 
3
  import gradio as gr
4
  from PIL import Image
5
  import subprocess
6
+
7
  #os.chdir('Restormer')
8
+
9
  # Download sample images
10
+ os.system("wget https://github.com/swz30/Restormer/releases/download/v1.0/sample_images.zip")
11
+ shutil.unpack_archive('sample_images.zip')
12
+ os.remove('sample_images.zip')
13
+
 
 
 
 
 
 
 
 
14
 
15
  examples = [['project/cartoon2.jpg','project/video1.mp4'],
16
  ['project/cartoon3.jpg','project/video2.mp4'],
 
18
  ['project/celeb2.jpg','project/video2.mp4'],
19
  ]
20
 
 
 
 
21
  title = "DaGAN"
22
  description = """
23
  Gradio demo for <b>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</b>, CVPR 2022L. <a href='https://arxiv.org/abs/2203.06605'>[Paper]</a><a href='https://github.com/harlanhong/CVPR2022-DaGAN'>[Github Code]</a>\n
 
28
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.06605'>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</a> | <a href='https://github.com/harlanhong/CVPR2022-DaGAN'>Github Repo</a></p>"
29
 
30
 
31
+ def inference(img, video):
32
  if not os.path.exists('temp'):
33
+ os.system('mkdir temp')
34
+ #### Resize the longer edge of the input image
35
+ cmd = f"ffmpeg -y -ss 00:00:00 -i {video} -to 00:00:08 -c copy temp/driving_video.mp4"
36
  subprocess.run(cmd.split())
37
  driving_video = "video_input.mp4"
38
+ os.system("python demo_dagan.py --source_image {} --driving_video 'temp/driving_video.mp4' --output 'temp/rst.mp4'".format(img))
39
+ return f'temp/rst.mp4'
40
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  gr.Interface(
42
  inference,
43
  [
demo_dagan.py CHANGED
@@ -6,10 +6,19 @@
6
  import torch
7
  import torch.nn.functional as F
8
  import os
 
 
9
  import argparse
 
 
10
  from scipy.spatial import ConvexHull
11
  from tqdm import tqdm
12
  import numpy as np
 
 
 
 
 
13
  parser = argparse.ArgumentParser(description='Test DaGAN on your own images')
14
  parser.add_argument('--source_image', default='./temp/source.jpg', type=str, help='Directory of input source image')
15
  parser.add_argument('--driving_video', default='./temp/driving.mp4', type=str, help='Directory for driving video')
@@ -62,6 +71,7 @@ def find_best_frame(source, driving, cpu=False):
62
  frame_num = i
63
  return frame_num
64
 
 
65
  def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False):
66
  sources = []
67
  drivings = []
@@ -111,88 +121,88 @@ def make_animation(source_image, driving_video, generator, kp_detector, relative
111
  predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
112
  depth_gray.append(gray_driving)
113
  return sources, drivings, predictions,depth_gray
114
- # with open("config/vox-adv-256.yaml") as f:
115
- # config = yaml.load(f)
116
- # generator = G.SPADEDepthAwareGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params'])
117
- # config['model_params']['common_params']['num_channels'] = 4
118
- # kp_detector = KPD.KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params'])
119
- # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
120
-
121
-
122
- # g_checkpoint = torch.load("generator.pt", map_location=device)
123
- # kp_checkpoint = torch.load("kp_detector.pt", map_location=device)
124
-
125
- # ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in g_checkpoint.items())
126
- # generator.load_state_dict(ckp_generator)
127
- # ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in kp_checkpoint.items())
128
- # kp_detector.load_state_dict(ckp_kp_detector)
129
-
130
- # depth_encoder = depth.ResnetEncoder(18, False)
131
- # depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
132
- # loaded_dict_enc = torch.load('encoder.pth')
133
- # loaded_dict_dec = torch.load('depth.pth')
134
- # filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
135
- # depth_encoder.load_state_dict(filtered_dict_enc)
136
- # ckp_depth_decoder= {k: v for k, v in loaded_dict_dec.items() if k in depth_decoder.state_dict()}
137
- # depth_decoder.load_state_dict(ckp_depth_decoder)
138
- # depth_encoder.eval()
139
- # depth_decoder.eval()
140
 
141
- # # device = torch.device('cpu')
142
- # # stx()
143
-
144
- # generator = generator.to(device)
145
- # kp_detector = kp_detector.to(device)
146
- # depth_encoder = depth_encoder.to(device)
147
- # depth_decoder = depth_decoder.to(device)
148
-
149
- # generator.eval()
150
- # kp_detector.eval()
151
- # depth_encoder.eval()
152
- # depth_decoder.eval()
153
-
154
- # img_multiple_of = 8
155
-
156
- # with torch.inference_mode():
157
- # if torch.cuda.is_available():
158
- # torch.cuda.ipc_collect()
159
- # torch.cuda.empty_cache()
160
- # source_image = imageio.imread(args.source_image)
161
- # reader = imageio.get_reader(args.driving_video)
162
- # fps = reader.get_meta_data()['fps']
163
- # driving_video = []
164
- # try:
165
- # for im in reader:
166
- # driving_video.append(im)
167
- # except RuntimeError:
168
- # pass
169
- # reader.close()
170
-
171
- # source_image = resize(source_image, (256, 256))[..., :3]
172
- # driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
173
-
174
-
175
-
176
- # i = find_best_frame(source_image, driving_video)
177
- # print ("Best frame: " + str(i))
178
- # driving_forward = driving_video[i:]
179
- # driving_backward = driving_video[:(i+1)][::-1]
180
- # sources_forward, drivings_forward, predictions_forward,depth_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
181
- # sources_backward, drivings_backward, predictions_backward,depth_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
182
- # predictions = predictions_backward[::-1] + predictions_forward[1:]
183
- # sources = sources_backward[::-1] + sources_forward[1:]
184
- # drivings = drivings_backward[::-1] + drivings_forward[1:]
185
- # depth_gray = depth_backward[::-1] + depth_forward[1:]
186
-
187
- # imageio.mimsave(args.output, [np.concatenate((img_as_ubyte(s),img_as_ubyte(d),img_as_ubyte(p)),1) for (s,d,p) in zip(sources, drivings, predictions)], fps=fps)
188
- # imageio.mimsave("gray.mp4", depth_gray, fps=fps)
189
- # # merge the gray video
190
- # animation = np.array(imageio.mimread(args.output,memtest=False))
191
- # gray = np.array(imageio.mimread("gray.mp4",memtest=False))
192
-
193
- # src_dst = animation[:,:,:512,:]
194
- # animate = animation[:,:,512:,:]
195
- # merge = np.concatenate((src_dst,gray,animate),2)
196
- # imageio.mimsave(args.output, merge, fps=fps)
197
 
198
  # print(f"\nRestored images are saved at {out_dir}")
 
6
  import torch
7
  import torch.nn.functional as F
8
  import os
9
+ from skimage import img_as_ubyte
10
+ import cv2
11
  import argparse
12
+ import imageio
13
+ from skimage.transform import resize
14
  from scipy.spatial import ConvexHull
15
  from tqdm import tqdm
16
  import numpy as np
17
+ import modules.generator as G
18
+ import modules.keypoint_detector as KPD
19
+ import yaml
20
+ from collections import OrderedDict
21
+ import depth
22
  parser = argparse.ArgumentParser(description='Test DaGAN on your own images')
23
  parser.add_argument('--source_image', default='./temp/source.jpg', type=str, help='Directory of input source image')
24
  parser.add_argument('--driving_video', default='./temp/driving.mp4', type=str, help='Directory for driving video')
 
71
  frame_num = i
72
  return frame_num
73
 
74
+
75
  def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False):
76
  sources = []
77
  drivings = []
 
121
  predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
122
  depth_gray.append(gray_driving)
123
  return sources, drivings, predictions,depth_gray
124
+ with open("config/vox-adv-256.yaml") as f:
125
+ config = yaml.load(f)
126
+ generator = G.SPADEDepthAwareGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params'])
127
+ config['model_params']['common_params']['num_channels'] = 4
128
+ kp_detector = KPD.KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params'])
129
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
130
+
131
+
132
+ g_checkpoint = torch.load("generator.pt", map_location=device)
133
+ kp_checkpoint = torch.load("kp_detector.pt", map_location=device)
134
+
135
+ ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in g_checkpoint.items())
136
+ generator.load_state_dict(ckp_generator)
137
+ ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in kp_checkpoint.items())
138
+ kp_detector.load_state_dict(ckp_kp_detector)
139
+
140
+ depth_encoder = depth.ResnetEncoder(18, False)
141
+ depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
142
+ loaded_dict_enc = torch.load('encoder.pth')
143
+ loaded_dict_dec = torch.load('depth.pth')
144
+ filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
145
+ depth_encoder.load_state_dict(filtered_dict_enc)
146
+ ckp_depth_decoder= {k: v for k, v in loaded_dict_dec.items() if k in depth_decoder.state_dict()}
147
+ depth_decoder.load_state_dict(ckp_depth_decoder)
148
+ depth_encoder.eval()
149
+ depth_decoder.eval()
150
 
151
+ # device = torch.device('cpu')
152
+ # stx()
153
+
154
+ generator = generator.to(device)
155
+ kp_detector = kp_detector.to(device)
156
+ depth_encoder = depth_encoder.to(device)
157
+ depth_decoder = depth_decoder.to(device)
158
+
159
+ generator.eval()
160
+ kp_detector.eval()
161
+ depth_encoder.eval()
162
+ depth_decoder.eval()
163
+
164
+ img_multiple_of = 8
165
+
166
+ with torch.inference_mode():
167
+ if torch.cuda.is_available():
168
+ torch.cuda.ipc_collect()
169
+ torch.cuda.empty_cache()
170
+ source_image = imageio.imread(args.source_image)
171
+ reader = imageio.get_reader(args.driving_video)
172
+ fps = reader.get_meta_data()['fps']
173
+ driving_video = []
174
+ try:
175
+ for im in reader:
176
+ driving_video.append(im)
177
+ except RuntimeError:
178
+ pass
179
+ reader.close()
180
+
181
+ source_image = resize(source_image, (256, 256))[..., :3]
182
+ driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
183
+
184
+
185
+
186
+ i = find_best_frame(source_image, driving_video)
187
+ print ("Best frame: " + str(i))
188
+ driving_forward = driving_video[i:]
189
+ driving_backward = driving_video[:(i+1)][::-1]
190
+ sources_forward, drivings_forward, predictions_forward,depth_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
191
+ sources_backward, drivings_backward, predictions_backward,depth_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
192
+ predictions = predictions_backward[::-1] + predictions_forward[1:]
193
+ sources = sources_backward[::-1] + sources_forward[1:]
194
+ drivings = drivings_backward[::-1] + drivings_forward[1:]
195
+ depth_gray = depth_backward[::-1] + depth_forward[1:]
196
+
197
+ imageio.mimsave(args.output, [np.concatenate((img_as_ubyte(s),img_as_ubyte(d),img_as_ubyte(p)),1) for (s,d,p) in zip(sources, drivings, predictions)], fps=fps)
198
+ imageio.mimsave("gray.mp4", depth_gray, fps=fps)
199
+ # merge the gray video
200
+ animation = np.array(imageio.mimread(args.output,memtest=False))
201
+ gray = np.array(imageio.mimread("gray.mp4",memtest=False))
202
+
203
+ src_dst = animation[:,:,:512,:]
204
+ animate = animation[:,:,512:,:]
205
+ merge = np.concatenate((src_dst,gray,animate),2)
206
+ imageio.mimsave(args.output, merge, fps=fps)
207
 
208
  # print(f"\nRestored images are saved at {out_dir}")
depth.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:11eb72a1e520d6086d9f357b6740340a235b067acdd6d495049877de2772d1a4
3
- size 12621521
 
 
 
 
encoder.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:de3d906dac888c2947cf0dabe319b8d3a5da98dd695d8b96512891f5c5a6bca3
3
- size 46837645
 
 
 
 
generator.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:34ac6a18ca3b0d9df080990d4975d9f4db04f7216fa9dbe4d580e920ee4b2bde
3
- size 270494161
 
 
 
 
kp_detector.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6f03aac403bf71445163f22cd7f883548980603065326c6b8ee08b74ad18d1bd
3
- size 57103620