fffiloni committed
Commit ca255d7 · verified · 1 Parent(s): d87127d

add gradio progress track

Files changed (1)
  1. src/gradio_pipeline.py +149 -148
src/gradio_pipeline.py CHANGED
@@ -1,148 +1,149 @@
-# coding: utf-8
-
-"""
-Pipeline for gradio
-"""
-import gradio as gr
-
-from .config.argument_config import ArgumentConfig
-from .live_portrait_pipeline import LivePortraitPipeline
-from .utils.io import load_img_online
-from .utils.rprint import rlog as log
-from .utils.crop import prepare_paste_back, paste_back
-from .utils.camera import get_rotation_matrix
-
-
-def update_args(args, user_args):
-    """update the args according to user inputs
-    """
-    for k, v in user_args.items():
-        if hasattr(args, k):
-            setattr(args, k, v)
-    return args
-
-
-class GradioPipeline(LivePortraitPipeline):
-
-    def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
-        super().__init__(inference_cfg, crop_cfg)
-        # self.live_portrait_wrapper = self.live_portrait_wrapper
-        self.args = args
-
-    def execute_video(
-        self,
-        input_image_path,
-        input_video_path,
-        flag_relative_input,
-        flag_do_crop_input,
-        flag_remap_input,
-        flag_crop_driving_video_input
-    ):
-        """ for video-driven portrait animation
-        """
-        if input_image_path is not None and input_video_path is not None:
-            args_user = {
-                'source_image': input_image_path,
-                'driving_info': input_video_path,
-                'flag_relative': flag_relative_input,
-                'flag_do_crop': flag_do_crop_input,
-                'flag_pasteback': flag_remap_input,
-                'flag_crop_driving_video': flag_crop_driving_video_input
-            }
-            # update config from user input
-            self.args = update_args(self.args, args_user)
-            self.live_portrait_wrapper.update_config(self.args.__dict__)
-            self.cropper.update_config(self.args.__dict__)
-            # video driven animation
-            video_path, video_path_concat = self.execute(self.args)
-            gr.Info("Run successfully!", duration=2)
-            return video_path, video_path_concat,
-        else:
-            raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
-
-    def execute_s_video(
-        self,
-        input_s_video_path,
-        input_video_path,
-        flag_relative_input,
-        flag_do_crop_input,
-        flag_remap_input,
-        flag_crop_driving_video_input
-    ):
-        """ for video-driven source-to-video animation
-        """
-        if input_s_video_path is not None and input_video_path is not None:
-            args_user = {
-                'source_driving_info': input_s_video_path,
-                'driving_info': input_video_path,
-                'flag_relative': flag_relative_input,
-                'flag_do_crop': flag_do_crop_input,
-                'flag_pasteback': flag_remap_input,
-                'flag_crop_driving_video': flag_crop_driving_video_input
-            }
-            # update config from user input
-            self.args = update_args(self.args, args_user)
-            self.live_portrait_wrapper.update_config(self.args.__dict__)
-            self.cropper.update_config(self.args.__dict__)
-            # video driven animation
-            video_path, video_path_concat = self.execute_source_video(self.args)
-            gr.Info("Run successfully!", duration=3)
-            return video_path, video_path_concat,
-        else:
-            raise gr.Error("The input source video or driving video hasn't been prepared yet 💥!", duration=5)
-
-    def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
-        """ for single image retargeting
-        """
-        # disposable feature
-        f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
-            self.prepare_retargeting(input_image, flag_do_crop)
-
-        if input_eye_ratio is None or input_lip_ratio is None:
-            raise gr.Error("Invalid ratio input 💥!", duration=5)
-        else:
-            inference_cfg = self.live_portrait_wrapper.inference_cfg
-            x_s_user = x_s_user.to(self.live_portrait_wrapper.device)
-            f_s_user = f_s_user.to(self.live_portrait_wrapper.device)
-            # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
-            combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
-            eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
-            # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
-            combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
-            lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
-            num_kp = x_s_user.shape[1]
-            # default: use x_s
-            x_d_new = x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
-            # D(W(f_s; x_s, x′_d))
-            out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
-            out = self.live_portrait_wrapper.parse_output(out['out'])[0]
-            out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
-            gr.Info("Run successfully!", duration=2)
-            return out, out_to_ori_blend
-
-    def prepare_retargeting(self, input_image, flag_do_crop=True):
-        """ for single image retargeting
-        """
-        if input_image is not None:
-            # gr.Info("Upload successfully!", duration=2)
-            inference_cfg = self.live_portrait_wrapper.inference_cfg
-            ######## process source portrait ########
-            img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
-            log(f"Load source image from {input_image}.")
-            crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
-            if flag_do_crop:
-                I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
-            else:
-                I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
-            x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
-            R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
-            ############################################
-            f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
-            x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
-            source_lmk_user = crop_info['lmk_crop']
-            crop_M_c2o = crop_info['M_c2o']
-            mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
-            return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
-        else:
-            # when press the clear button, go here
-            raise gr.Error("The retargeting input hasn't been prepared yet 💥!", duration=5)
+# coding: utf-8
+
+"""
+Pipeline for gradio
+"""
+import gradio as gr
+
+from .config.argument_config import ArgumentConfig
+from .live_portrait_pipeline import LivePortraitPipeline
+from .utils.io import load_img_online
+from .utils.rprint import rlog as log
+from .utils.crop import prepare_paste_back, paste_back
+from .utils.camera import get_rotation_matrix
+
+
+def update_args(args, user_args):
+    """update the args according to user inputs
+    """
+    for k, v in user_args.items():
+        if hasattr(args, k):
+            setattr(args, k, v)
+    return args
+
+
+class GradioPipeline(LivePortraitPipeline):
+
+    def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
+        super().__init__(inference_cfg, crop_cfg)
+        # self.live_portrait_wrapper = self.live_portrait_wrapper
+        self.args = args
+
+    def execute_video(
+        self,
+        input_image_path,
+        input_video_path,
+        flag_relative_input,
+        flag_do_crop_input,
+        flag_remap_input,
+        flag_crop_driving_video_input
+    ):
+        """ for video-driven portrait animation
+        """
+        if input_image_path is not None and input_video_path is not None:
+            args_user = {
+                'source_image': input_image_path,
+                'driving_info': input_video_path,
+                'flag_relative': flag_relative_input,
+                'flag_do_crop': flag_do_crop_input,
+                'flag_pasteback': flag_remap_input,
+                'flag_crop_driving_video': flag_crop_driving_video_input
+            }
+            # update config from user input
+            self.args = update_args(self.args, args_user)
+            self.live_portrait_wrapper.update_config(self.args.__dict__)
+            self.cropper.update_config(self.args.__dict__)
+            # video driven animation
+            video_path, video_path_concat = self.execute(self.args)
+            gr.Info("Run successfully!", duration=2)
+            return video_path, video_path_concat,
+        else:
+            raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
+
+    def execute_s_video(
+        self,
+        input_s_video_path,
+        input_video_path,
+        flag_relative_input,
+        flag_do_crop_input,
+        flag_remap_input,
+        flag_crop_driving_video_input,
+        progress=gr.Progress(track_tqdm=True)
+    ):
+        """ for video-driven source-to-video animation
+        """
+        if input_s_video_path is not None and input_video_path is not None:
+            args_user = {
+                'source_driving_info': input_s_video_path,
+                'driving_info': input_video_path,
+                'flag_relative': flag_relative_input,
+                'flag_do_crop': flag_do_crop_input,
+                'flag_pasteback': flag_remap_input,
+                'flag_crop_driving_video': flag_crop_driving_video_input
+            }
+            # update config from user input
+            self.args = update_args(self.args, args_user)
+            self.live_portrait_wrapper.update_config(self.args.__dict__)
+            self.cropper.update_config(self.args.__dict__)
+            # video driven animation
+            video_path, video_path_concat = self.execute_source_video(self.args)
+            gr.Info("Run successfully!", duration=3)
+            return video_path, video_path_concat,
+        else:
+            raise gr.Error("The input source video or driving video hasn't been prepared yet 💥!", duration=5)
+
+    def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
+        """ for single image retargeting
+        """
+        # disposable feature
+        f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
+            self.prepare_retargeting(input_image, flag_do_crop)
+
+        if input_eye_ratio is None or input_lip_ratio is None:
+            raise gr.Error("Invalid ratio input 💥!", duration=5)
+        else:
+            inference_cfg = self.live_portrait_wrapper.inference_cfg
+            x_s_user = x_s_user.to(self.live_portrait_wrapper.device)
+            f_s_user = f_s_user.to(self.live_portrait_wrapper.device)
+            # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
+            combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
+            eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
+            # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
+            combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
+            lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
+            num_kp = x_s_user.shape[1]
+            # default: use x_s
+            x_d_new = x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
+            # D(W(f_s; x_s, x′_d))
+            out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
+            out = self.live_portrait_wrapper.parse_output(out['out'])[0]
+            out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
+            gr.Info("Run successfully!", duration=2)
+            return out, out_to_ori_blend
+
+    def prepare_retargeting(self, input_image, flag_do_crop=True):
+        """ for single image retargeting
+        """
+        if input_image is not None:
+            # gr.Info("Upload successfully!", duration=2)
+            inference_cfg = self.live_portrait_wrapper.inference_cfg
+            ######## process source portrait ########
+            img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
+            log(f"Load source image from {input_image}.")
+            crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
+            if flag_do_crop:
+                I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
+            else:
+                I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
+            x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
+            R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
+            ############################################
+            f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
+            x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
+            source_lmk_user = crop_info['lmk_crop']
+            crop_M_c2o = crop_info['M_c2o']
+            mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
+            return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
+        else:
+            # when press the clear button, go here
+            raise gr.Error("The retargeting input hasn't been prepared yet 💥!", duration=5)