cleardusk commited on
Commit
58ca92c
·
1 Parent(s): 3cf21f4

fix: FPS and image shape of retarget

Browse files
app.py CHANGED
@@ -116,6 +116,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
116
  flag_relative_input = gr.Checkbox(value=True, label="relative motion")
117
  flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
118
  flag_remap_input = gr.Checkbox(value=True, label="paste-back")
 
119
  with gr.Row():
120
  with gr.Column():
121
  process_button_animation = gr.Button("🚀 Animate", variant="primary")
 
116
  flag_relative_input = gr.Checkbox(value=True, label="relative motion")
117
  flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
118
  flag_remap_input = gr.Checkbox(value=True, label="paste-back")
119
+ gr.Markdown(load_description("assets/gradio_description_animate_clear.md"))
120
  with gr.Row():
121
  with gr.Column():
122
  process_button_animation = gr.Button("🚀 Animate", variant="primary")
assets/gradio_description_animate_clear.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ <div style="font-size: 1.2em; text-align: center;">
2
+ Step 3: click the <strong>🚀 Animate</strong> button below to generate, or click 🧹 Clear to erase the results
3
+ </div>
assets/gradio_description_upload.md CHANGED
@@ -1,8 +1,24 @@
1
- ## 🤗 This is the official gradio demo for LivePortrait.
2
- <div style="font-size: 1.0em;">
3
  If you find LivePortrait fun 🤪 or useful, please consider starring 🌟 our <a href="https://github.com/KwaiVGI/LivePortrait">GitHub Repo</a> to discover more features!
4
  </div>
 
 
 
 
 
 
 
5
  <br>
6
- <div style="font-size: 1.2em;">
7
- Please upload or use a webcam to get a <strong>Source Portrait</strong> (any aspect ratio) and upload a <strong>Driving Video</strong> (1:1 aspect ratio).
 
 
 
 
 
 
 
 
 
8
  </div>
 
1
+ <!-- ## 🤗 This is the official gradio demo for LivePortrait. -->
2
+ <!-- <div style="font-size: 1.0em;">
3
  If you find LivePortrait fun 🤪 or useful, please consider starring 🌟 our <a href="https://github.com/KwaiVGI/LivePortrait">GitHub Repo</a> to discover more features!
4
  </div>
5
+
6
+
7
+ <!-- <div style="font-size: 1.2em;">
8
+ Step1: upload or use a webcam to get a <strong>Source Portrait</strong> (any aspect ratio) to left side.<br>
9
+ Step2: upload a <strong>Driving Video</strong> (1:1 aspect ratio) to right side.
10
+ </div> -->
11
+
12
  <br>
13
+ <div style="font-size: 1.2em; display: flex; justify-content: space-between;">
14
+ <div style="flex: 1; text-align: center; margin-right: 20px;">
15
+ <div style="display: inline-block;">
16
+ Step 1: upload a <strong>Source Portrait</strong> (any aspect ratio) ⬇️
17
+ </div>
18
+ </div>
19
+ <div style="flex: 1; text-align: center; margin-left: 20px;">
20
+ <div style="display: inline-block;">
21
+ Step 2: upload a <strong>Driving Video</strong> (1:1 aspect ratio) ⬇️
22
+ </div>
23
+ </div>
24
  </div>
assets/gradio_title.md CHANGED
@@ -9,6 +9,9 @@
9
  <a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
10
  &nbsp;
11
  <a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
 
 
 
12
  </div>
13
  </div>
14
  </div>
 
9
  <a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
10
  &nbsp;
11
  <a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
12
+ &nbsp;
13
+ <a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/github/stars/KwaiVGI/LivePortrait
14
+ "></a>
15
  </div>
16
  </div>
17
  </div>
src/gradio_pipeline.py CHANGED
@@ -9,7 +9,7 @@ from .live_portrait_pipeline import LivePortraitPipeline
9
  from .utils.io import load_img_online
10
  from .utils.rprint import rlog as log
11
  from .utils.crop import prepare_paste_back, paste_back
12
- from .utils.camera import get_rotation_matrix
13
 
14
  def update_args(args, user_args):
15
  """update the args according to user inputs
@@ -33,7 +33,7 @@ class GradioPipeline(LivePortraitPipeline):
33
  flag_relative_input,
34
  flag_do_crop_input,
35
  flag_remap_input,
36
- ):
37
  """ for video driven potrait animation
38
  """
39
  if input_image_path is not None and input_video_path is not None:
@@ -91,7 +91,7 @@ class GradioPipeline(LivePortraitPipeline):
91
  # gr.Info("Upload successfully!", duration=2)
92
  inference_cfg = self.live_portrait_wrapper.cfg
93
  ######## process source portrait ########
94
- img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
95
  log(f"Load source image from {input_image}.")
96
  crop_info = self.cropper.crop_single_image(img_rgb)
97
  if flag_do_crop:
@@ -99,7 +99,7 @@ class GradioPipeline(LivePortraitPipeline):
99
  else:
100
  I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
101
  x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
102
- R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
103
  ############################################
104
  f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
105
  x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
 
9
  from .utils.io import load_img_online
10
  from .utils.rprint import rlog as log
11
  from .utils.crop import prepare_paste_back, paste_back
12
+ # from .utils.camera import get_rotation_matrix
13
 
14
  def update_args(args, user_args):
15
  """update the args according to user inputs
 
33
  flag_relative_input,
34
  flag_do_crop_input,
35
  flag_remap_input,
36
+ ):
37
  """ for video driven potrait animation
38
  """
39
  if input_image_path is not None and input_video_path is not None:
 
91
  # gr.Info("Upload successfully!", duration=2)
92
  inference_cfg = self.live_portrait_wrapper.cfg
93
  ######## process source portrait ########
94
+ img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=1) # n=1 means do not trim the pixels
95
  log(f"Load source image from {input_image}.")
96
  crop_info = self.cropper.crop_single_image(img_rgb)
97
  if flag_do_crop:
 
99
  else:
100
  I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
101
  x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
102
+ # R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
103
  ############################################
104
  f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
105
  x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
src/live_portrait_pipeline.py CHANGED
@@ -18,7 +18,7 @@ from .config.inference_config import InferenceConfig
18
  from .config.crop_config import CropConfig
19
  from .utils.cropper import Cropper
20
  from .utils.camera import get_rotation_matrix
21
- from .utils.video import images2video, concat_frames
22
  from .utils.crop import _transform_img, prepare_paste_back, paste_back
23
  from .utils.retargeting_utils import calc_lip_close_ratio
24
  from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
@@ -67,8 +67,12 @@ class LivePortraitPipeline(object):
67
  ############################################
68
 
69
  ######## process driving info ########
 
70
  if is_video(args.driving_info):
71
  log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
 
 
 
72
  # TODO: 这里track一下驱动视频 -> 构建模板
73
  driving_rgb_lst = load_driving_info(args.driving_info)
74
  driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
@@ -177,13 +181,13 @@ class LivePortraitPipeline(object):
177
  frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
178
  # save (driving frames, source image, drived frames) result
179
  wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
180
- images2video(frames_concatenated, wfp=wfp_concat)
181
 
182
  # save drived result
183
  wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
184
  if inference_cfg.flag_pasteback:
185
- images2video(I_p_paste_lst, wfp=wfp)
186
  else:
187
- images2video(I_p_lst, wfp=wfp)
188
 
189
  return wfp, wfp_concat
 
18
  from .config.crop_config import CropConfig
19
  from .utils.cropper import Cropper
20
  from .utils.camera import get_rotation_matrix
21
+ from .utils.video import images2video, concat_frames, get_fps
22
  from .utils.crop import _transform_img, prepare_paste_back, paste_back
23
  from .utils.retargeting_utils import calc_lip_close_ratio
24
  from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
 
67
  ############################################
68
 
69
  ######## process driving info ########
70
+ output_fps = 30 # default fps
71
  if is_video(args.driving_info):
72
  log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
73
+ output_fps = int(get_fps(args.driving_info))
74
+ log(f'The FPS of {args.driving_info} is: {output_fps}')
75
+
76
  # TODO: 这里track一下驱动视频 -> 构建模板
77
  driving_rgb_lst = load_driving_info(args.driving_info)
78
  driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
 
181
  frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
182
  # save (driving frames, source image, drived frames) result
183
  wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
184
+ images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
185
 
186
  # save drived result
187
  wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
188
  if inference_cfg.flag_pasteback:
189
+ images2video(I_p_paste_lst, wfp=wfp, fps=output_fps)
190
  else:
191
+ images2video(I_p_lst, wfp=wfp, fps=output_fps)
192
 
193
  return wfp, wfp_concat
src/utils/video.py CHANGED
@@ -12,6 +12,7 @@ import cv2
12
 
13
  from rich.progress import track
14
  from .helper import prefix
 
15
  from .rprint import rprint as print
16
 
17
 
@@ -131,9 +132,14 @@ def change_video_fps(input_file, output_file, fps=20, codec='libx264', crf=5):
131
  exec_cmd(cmd)
132
 
133
 
134
- def get_fps(filepath):
135
- import ffmpeg
136
- probe = ffmpeg.probe(filepath)
137
- video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
138
- fps = eval(video_stream['avg_frame_rate'])
 
 
 
 
 
139
  return fps
 
12
 
13
  from rich.progress import track
14
  from .helper import prefix
15
+ from .rprint import rlog as log
16
  from .rprint import rprint as print
17
 
18
 
 
132
  exec_cmd(cmd)
133
 
134
 
135
+ def get_fps(filepath, default_fps=25):
136
+ try:
137
+ fps = cv2.VideoCapture(filepath).get(cv2.CAP_PROP_FPS)
138
+
139
+ if fps in (0, None):
140
+ fps = default_fps
141
+ except Exception as e:
142
+ log(e)
143
+ fps = default_fps
144
+
145
  return fps