Spaces:
Sleeping
Sleeping
Commit
•
a891a57
1
Parent(s):
7454c19
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- .gitignore +17 -0
- .vscode/settings.json +19 -0
- LICENSE +21 -0
- app.py +154 -0
- assets/docs/inference.gif +0 -0
- assets/docs/showcase.gif +3 -0
- assets/docs/showcase2.gif +3 -0
- assets/examples/driving/d0.mp4 +3 -0
- assets/examples/driving/d1.mp4 +0 -0
- assets/examples/driving/d2.mp4 +0 -0
- assets/examples/driving/d3.mp4 +3 -0
- assets/examples/driving/d5.mp4 +0 -0
- assets/examples/driving/d6.mp4 +3 -0
- assets/examples/driving/d7.mp4 +0 -0
- assets/examples/driving/d8.mp4 +0 -0
- assets/examples/driving/d9.mp4 +3 -0
- assets/examples/source/s0.jpg +0 -0
- assets/examples/source/s1.jpg +0 -0
- assets/examples/source/s10.jpg +0 -0
- assets/examples/source/s2.jpg +0 -0
- assets/examples/source/s3.jpg +0 -0
- assets/examples/source/s4.jpg +0 -0
- assets/examples/source/s5.jpg +0 -0
- assets/examples/source/s6.jpg +0 -0
- assets/examples/source/s7.jpg +0 -0
- assets/examples/source/s8.jpg +0 -0
- assets/examples/source/s9.jpg +0 -0
- assets/gradio_description_animation.md +7 -0
- assets/gradio_description_retargeting.md +1 -0
- assets/gradio_description_upload.md +2 -0
- assets/gradio_title.md +10 -0
- inference.py +33 -0
- pretrained_weights/.gitkeep +0 -0
- readme.md +143 -0
- requirements.txt +22 -0
- speed.py +192 -0
- src/config/__init__.py +0 -0
- src/config/argument_config.py +44 -0
- src/config/base_config.py +29 -0
- src/config/crop_config.py +18 -0
- src/config/inference_config.py +49 -0
- src/config/models.yaml +43 -0
- src/gradio_pipeline.py +140 -0
- src/live_portrait_pipeline.py +190 -0
- src/live_portrait_wrapper.py +307 -0
- src/modules/__init__.py +0 -0
- src/modules/appearance_feature_extractor.py +48 -0
- src/modules/convnextv2.py +149 -0
- src/modules/dense_motion.py +104 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/docs/showcase.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/docs/showcase2.gif filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/examples/driving/d0.mp4 filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/examples/driving/d3.mp4 filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/examples/driving/d6.mp4 filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/examples/driving/d9.mp4 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
**/__pycache__/
|
4 |
+
*.py[cod]
|
5 |
+
**/*.py[cod]
|
6 |
+
*$py.class
|
7 |
+
|
8 |
+
# Model weights
|
9 |
+
**/*.pth
|
10 |
+
**/*.onnx
|
11 |
+
|
12 |
+
# Ipython notebook
|
13 |
+
*.ipynb
|
14 |
+
|
15 |
+
# Temporary files or benchmark resources
|
16 |
+
animations/*
|
17 |
+
tmp/*
|
.vscode/settings.json
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"[python]": {
|
3 |
+
"editor.tabSize": 4
|
4 |
+
},
|
5 |
+
"files.eol": "\n",
|
6 |
+
"files.insertFinalNewline": true,
|
7 |
+
"files.trimFinalNewlines": true,
|
8 |
+
"files.trimTrailingWhitespace": true,
|
9 |
+
"files.exclude": {
|
10 |
+
"**/.git": true,
|
11 |
+
"**/.svn": true,
|
12 |
+
"**/.hg": true,
|
13 |
+
"**/CVS": true,
|
14 |
+
"**/.DS_Store": true,
|
15 |
+
"**/Thumbs.db": true,
|
16 |
+
"**/*.crswap": true,
|
17 |
+
"**/__pycache__": true
|
18 |
+
}
|
19 |
+
}
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 Kuaishou Visual Generation and Interaction Center
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
app.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
The entrance of the gradio
|
5 |
+
"""
|
6 |
+
|
7 |
+
import tyro
|
8 |
+
import gradio as gr
|
9 |
+
import os.path as osp
|
10 |
+
from src.utils.helper import load_description
|
11 |
+
from src.gradio_pipeline import GradioPipeline
|
12 |
+
from src.config.crop_config import CropConfig
|
13 |
+
from src.config.argument_config import ArgumentConfig
|
14 |
+
from src.config.inference_config import InferenceConfig
|
15 |
+
|
16 |
+
|
17 |
+
def partial_fields(target_class, kwargs):
|
18 |
+
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
19 |
+
|
20 |
+
|
21 |
+
# set tyro theme
|
22 |
+
tyro.extras.set_accent_color("bright_cyan")
|
23 |
+
args = tyro.cli(ArgumentConfig)
|
24 |
+
|
25 |
+
# specify configs for inference
|
26 |
+
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
|
27 |
+
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
|
28 |
+
gradio_pipeline = GradioPipeline(
|
29 |
+
inference_cfg=inference_cfg,
|
30 |
+
crop_cfg=crop_cfg,
|
31 |
+
args=args
|
32 |
+
)
|
33 |
+
# assets
|
34 |
+
title_md = "assets/gradio_title.md"
|
35 |
+
example_portrait_dir = "assets/examples/source"
|
36 |
+
example_video_dir = "assets/examples/driving"
|
37 |
+
data_examples = [
|
38 |
+
[osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
|
39 |
+
[osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
|
40 |
+
[osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d5.mp4"), True, True, True, True],
|
41 |
+
[osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d6.mp4"), True, True, True, True],
|
42 |
+
[osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d7.mp4"), True, True, True, True],
|
43 |
+
]
|
44 |
+
#################### interface logic ####################
|
45 |
+
|
46 |
+
# Define components first
|
47 |
+
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
|
48 |
+
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
|
49 |
+
retargeting_input_image = gr.Image(type="numpy")
|
50 |
+
output_image = gr.Image(type="numpy")
|
51 |
+
output_image_paste_back = gr.Image(type="numpy")
|
52 |
+
output_video = gr.Video()
|
53 |
+
output_video_concat = gr.Video()
|
54 |
+
|
55 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
56 |
+
gr.HTML(load_description(title_md))
|
57 |
+
gr.Markdown(load_description("assets/gradio_description_upload.md"))
|
58 |
+
with gr.Row():
|
59 |
+
with gr.Accordion(open=True, label="Source Portrait"):
|
60 |
+
image_input = gr.Image(type="filepath")
|
61 |
+
with gr.Accordion(open=True, label="Driving Video"):
|
62 |
+
video_input = gr.Video()
|
63 |
+
gr.Markdown(load_description("assets/gradio_description_animation.md"))
|
64 |
+
with gr.Row():
|
65 |
+
with gr.Accordion(open=True, label="Animation Options"):
|
66 |
+
with gr.Row():
|
67 |
+
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
|
68 |
+
flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
|
69 |
+
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
|
70 |
+
with gr.Row():
|
71 |
+
with gr.Column():
|
72 |
+
process_button_animation = gr.Button("🚀 Animate", variant="primary")
|
73 |
+
with gr.Column():
|
74 |
+
process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
|
75 |
+
with gr.Row():
|
76 |
+
with gr.Column():
|
77 |
+
with gr.Accordion(open=True, label="The animated video in the original image space"):
|
78 |
+
output_video.render()
|
79 |
+
with gr.Column():
|
80 |
+
with gr.Accordion(open=True, label="The animated video"):
|
81 |
+
output_video_concat.render()
|
82 |
+
with gr.Row():
|
83 |
+
# Examples
|
84 |
+
gr.Markdown("## You could choose the examples below ⬇️")
|
85 |
+
with gr.Row():
|
86 |
+
gr.Examples(
|
87 |
+
examples=data_examples,
|
88 |
+
inputs=[
|
89 |
+
image_input,
|
90 |
+
video_input,
|
91 |
+
flag_relative_input,
|
92 |
+
flag_do_crop_input,
|
93 |
+
flag_remap_input
|
94 |
+
],
|
95 |
+
examples_per_page=5
|
96 |
+
)
|
97 |
+
gr.Markdown(load_description("assets/gradio_description_retargeting.md"))
|
98 |
+
with gr.Row():
|
99 |
+
eye_retargeting_slider.render()
|
100 |
+
lip_retargeting_slider.render()
|
101 |
+
with gr.Row():
|
102 |
+
process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
|
103 |
+
process_button_reset_retargeting = gr.ClearButton(
|
104 |
+
[
|
105 |
+
eye_retargeting_slider,
|
106 |
+
lip_retargeting_slider,
|
107 |
+
retargeting_input_image,
|
108 |
+
output_image,
|
109 |
+
output_image_paste_back
|
110 |
+
],
|
111 |
+
value="🧹 Clear"
|
112 |
+
)
|
113 |
+
with gr.Row():
|
114 |
+
with gr.Column():
|
115 |
+
with gr.Accordion(open=True, label="Retargeting Input"):
|
116 |
+
retargeting_input_image.render()
|
117 |
+
with gr.Column():
|
118 |
+
with gr.Accordion(open=True, label="Retargeting Result"):
|
119 |
+
output_image.render()
|
120 |
+
with gr.Column():
|
121 |
+
with gr.Accordion(open=True, label="Paste-back Result"):
|
122 |
+
output_image_paste_back.render()
|
123 |
+
# binding functions for buttons
|
124 |
+
process_button_retargeting.click(
|
125 |
+
fn=gradio_pipeline.execute_image,
|
126 |
+
inputs=[eye_retargeting_slider, lip_retargeting_slider],
|
127 |
+
outputs=[output_image, output_image_paste_back],
|
128 |
+
show_progress=True
|
129 |
+
)
|
130 |
+
process_button_animation.click(
|
131 |
+
fn=gradio_pipeline.execute_video,
|
132 |
+
inputs=[
|
133 |
+
image_input,
|
134 |
+
video_input,
|
135 |
+
flag_relative_input,
|
136 |
+
flag_do_crop_input,
|
137 |
+
flag_remap_input
|
138 |
+
],
|
139 |
+
outputs=[output_video, output_video_concat],
|
140 |
+
show_progress=True
|
141 |
+
)
|
142 |
+
image_input.change(
|
143 |
+
fn=gradio_pipeline.prepare_retargeting,
|
144 |
+
inputs=image_input,
|
145 |
+
outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
|
146 |
+
)
|
147 |
+
|
148 |
+
##########################################################
|
149 |
+
|
150 |
+
demo.launch(
|
151 |
+
server_name=args.server_name,
|
152 |
+
server_port=args.server_port,
|
153 |
+
share=args.share,
|
154 |
+
)
|
assets/docs/inference.gif
ADDED
assets/docs/showcase.gif
ADDED
Git LFS Details
|
assets/docs/showcase2.gif
ADDED
Git LFS Details
|
assets/examples/driving/d0.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:63f6f9962e1fdf6e6722172e7a18155204858d5d5ce3b1e0646c150360c33bed
|
3 |
+
size 2958395
|
assets/examples/driving/d1.mp4
ADDED
Binary file (48.8 kB). View file
|
|
assets/examples/driving/d2.mp4
ADDED
Binary file (47.8 kB). View file
|
|
assets/examples/driving/d3.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ef5c86e49b1b43dcb1449b499eb5a7f0cbae2f78aec08b5598193be1e4257099
|
3 |
+
size 1430968
|
assets/examples/driving/d5.mp4
ADDED
Binary file (135 kB). View file
|
|
assets/examples/driving/d6.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:00e3ea79bbf28cbdc4fbb67ec655d9a0fe876e880ec45af55ae481348d0c0fff
|
3 |
+
size 1967790
|
assets/examples/driving/d7.mp4
ADDED
Binary file (185 kB). View file
|
|
assets/examples/driving/d8.mp4
ADDED
Binary file (312 kB). View file
|
|
assets/examples/driving/d9.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9a414aa1d547be35306d692065a2157434bf40a6025ba8e30ce12e5bb322cc33
|
3 |
+
size 2257929
|
assets/examples/source/s0.jpg
ADDED
assets/examples/source/s1.jpg
ADDED
assets/examples/source/s10.jpg
ADDED
assets/examples/source/s2.jpg
ADDED
assets/examples/source/s3.jpg
ADDED
assets/examples/source/s4.jpg
ADDED
assets/examples/source/s5.jpg
ADDED
assets/examples/source/s6.jpg
ADDED
assets/examples/source/s7.jpg
ADDED
assets/examples/source/s8.jpg
ADDED
assets/examples/source/s9.jpg
ADDED
assets/gradio_description_animation.md
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<span style="font-size: 1.2em;">🔥 To animate the source portrait with the driving video, please follow these steps:</span>
|
2 |
+
<div style="font-size: 1.2em; margin-left: 20px;">
|
3 |
+
1. Specify the options in the <strong>Animation Options</strong> section. We recommend checking the <strong>do crop</strong> option when facial areas occupy a relatively small portion of your image.
|
4 |
+
</div>
|
5 |
+
<div style="font-size: 1.2em; margin-left: 20px;">
|
6 |
+
2. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments.
|
7 |
+
</div>
|
assets/gradio_description_retargeting.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
<span style="font-size: 1.2em;">🔥 To change the target eyes-open and lip-open ratio of the source portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
|
assets/gradio_description_upload.md
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
## 🤗 This is the official gradio demo for **LivePortrait**.
|
2 |
+
<div style="font-size: 1.2em;">Please upload or use the webcam to get a source portrait to the <strong>Source Portrait</strong> field and a driving video to the <strong>Driving Video</strong> field.</div>
|
assets/gradio_title.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
2 |
+
<div>
|
3 |
+
<h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
|
4 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;>
|
5 |
+
<a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
|
6 |
+
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
|
7 |
+
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
|
8 |
+
</div>
|
9 |
+
</div>
|
10 |
+
</div>
|
inference.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
import tyro
|
4 |
+
from src.config.argument_config import ArgumentConfig
|
5 |
+
from src.config.inference_config import InferenceConfig
|
6 |
+
from src.config.crop_config import CropConfig
|
7 |
+
from src.live_portrait_pipeline import LivePortraitPipeline
|
8 |
+
|
9 |
+
|
10 |
+
def partial_fields(target_class, kwargs):
|
11 |
+
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
12 |
+
|
13 |
+
|
14 |
+
def main():
|
15 |
+
# set tyro theme
|
16 |
+
tyro.extras.set_accent_color("bright_cyan")
|
17 |
+
args = tyro.cli(ArgumentConfig)
|
18 |
+
|
19 |
+
# specify configs for inference
|
20 |
+
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
|
21 |
+
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
|
22 |
+
|
23 |
+
live_portrait_pipeline = LivePortraitPipeline(
|
24 |
+
inference_cfg=inference_cfg,
|
25 |
+
crop_cfg=crop_cfg
|
26 |
+
)
|
27 |
+
|
28 |
+
# run
|
29 |
+
live_portrait_pipeline.execute(args)
|
30 |
+
|
31 |
+
|
32 |
+
if __name__ == '__main__':
|
33 |
+
main()
|
pretrained_weights/.gitkeep
ADDED
File without changes
|
readme.md
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<h1 align="center">LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
|
2 |
+
|
3 |
+
<div align='center'>
|
4 |
+
<a href='https://github.com/cleardusk' target='_blank'><strong>Jianzhu Guo</strong></a><sup> 1†</sup> 
|
5 |
+
<a href='https://github.com/KwaiVGI' target='_blank'><strong>Dingyun Zhang</strong></a><sup> 1,2</sup> 
|
6 |
+
<a href='https://github.com/KwaiVGI' target='_blank'><strong>Xiaoqiang Liu</strong></a><sup> 1</sup> 
|
7 |
+
<a href='https://github.com/KwaiVGI' target='_blank'><strong>Zhizhou Zhong</strong></a><sup> 1,3</sup> 
|
8 |
+
<a href='https://scholar.google.com.hk/citations?user=_8k1ubAAAAAJ' target='_blank'><strong>Yuan Zhang</strong></a><sup> 1</sup> 
|
9 |
+
</div>
|
10 |
+
|
11 |
+
<div align='center'>
|
12 |
+
<a href='https://scholar.google.com/citations?user=P6MraaYAAAAJ' target='_blank'><strong>Pengfei Wan</strong></a><sup> 1</sup> 
|
13 |
+
<a href='https://openreview.net/profile?id=~Di_ZHANG3' target='_blank'><strong>Di Zhang</strong></a><sup> 1</sup> 
|
14 |
+
</div>
|
15 |
+
|
16 |
+
<div align='center'>
|
17 |
+
<sup>1 </sup>Kuaishou Technology  <sup>2 </sup>University of Science and Technology of China  <sup>3 </sup>Fudan University 
|
18 |
+
</div>
|
19 |
+
|
20 |
+
<br>
|
21 |
+
<div align="center">
|
22 |
+
<!-- <a href='LICENSE'><img src='https://img.shields.io/badge/license-MIT-yellow'></a> -->
|
23 |
+
<a href='https://liveportrait.github.io'><img src='https://img.shields.io/badge/Project-Homepage-green'></a>
|
24 |
+
<a href='https://arxiv.org/pdf/2407.03168'><img src='https://img.shields.io/badge/Paper-arXiv-red'></a>
|
25 |
+
</div>
|
26 |
+
<br>
|
27 |
+
|
28 |
+
<p align="center">
|
29 |
+
<img src="./assets/docs/showcase2.gif" alt="showcase">
|
30 |
+
<br>
|
31 |
+
🔥 For more results, visit our <a href="https://liveportrait.github.io/"><strong>homepage</strong></a> 🔥
|
32 |
+
</p>
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
## 🔥 Updates
|
37 |
+
- **`2024/07/04`**: 🔥 We released the initial version of the inference code and models. Continuous updates, stay tuned!
|
38 |
+
- **`2024/07/04`**: 😊 We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168).
|
39 |
+
|
40 |
+
## Introduction
|
41 |
+
This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168).
|
42 |
+
We are actively updating and improving this repository. If you find any bugs or have suggestions, welcome to raise issues or submit pull requests (PR) 💖.
|
43 |
+
|
44 |
+
## 🔥 Getting Started
|
45 |
+
### 1. Clone the code and prepare the environment
|
46 |
+
```bash
|
47 |
+
git clone https://github.com/KwaiVGI/LivePortrait
|
48 |
+
cd LivePortrait
|
49 |
+
|
50 |
+
# create env using conda
|
51 |
+
conda create -n LivePortrait python==3.9.18
|
52 |
+
conda activate LivePortrait
|
53 |
+
# install dependencies with pip
|
54 |
+
pip install -r requirements.txt
|
55 |
+
```
|
56 |
+
|
57 |
+
### 2. Download pretrained weights
|
58 |
+
Download our pretrained LivePortrait weights and face detection models of InsightFace from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). We have packed all weights in one directory 😊. Unzip and place them in `./pretrained_weights` ensuring the directory structure is as follows:
|
59 |
+
```text
|
60 |
+
pretrained_weights
|
61 |
+
├── insightface
|
62 |
+
│ └── models
|
63 |
+
│ └── buffalo_l
|
64 |
+
│ ├── 2d106det.onnx
|
65 |
+
│ └── det_10g.onnx
|
66 |
+
└── liveportrait
|
67 |
+
├── base_models
|
68 |
+
│ ├── appearance_feature_extractor.pth
|
69 |
+
│ ├── motion_extractor.pth
|
70 |
+
│ ├── spade_generator.pth
|
71 |
+
│ └── warping_module.pth
|
72 |
+
├── landmark.onnx
|
73 |
+
└── retargeting_models
|
74 |
+
└── stitching_retargeting_module.pth
|
75 |
+
```
|
76 |
+
|
77 |
+
### 3. Inference 🚀
|
78 |
+
|
79 |
+
```bash
|
80 |
+
python inference.py
|
81 |
+
```
|
82 |
+
|
83 |
+
If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image, and generated result.
|
84 |
+
|
85 |
+
<p align="center">
|
86 |
+
<img src="./assets/docs/inference.gif" alt="image">
|
87 |
+
</p>
|
88 |
+
|
89 |
+
Or, you can change the input by specifying the `-s` and `-d` arguments:
|
90 |
+
|
91 |
+
```bash
|
92 |
+
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4
|
93 |
+
|
94 |
+
# or disable pasting back
|
95 |
+
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --no_flag_pasteback
|
96 |
+
|
97 |
+
# more options to see
|
98 |
+
python inference.py -h
|
99 |
+
```
|
100 |
+
|
101 |
+
**More interesting results can be found in our [Homepage](https://liveportrait.github.io)** 😊
|
102 |
+
|
103 |
+
### 4. Gradio interface
|
104 |
+
|
105 |
+
We also provide a Gradio interface for a better experience, just run by:
|
106 |
+
|
107 |
+
```bash
|
108 |
+
python app.py
|
109 |
+
```
|
110 |
+
|
111 |
+
### 5. Inference speed evaluation 🚀🚀🚀
|
112 |
+
We have also provided a script to evaluate the inference speed of each module:
|
113 |
+
|
114 |
+
```bash
|
115 |
+
python speed.py
|
116 |
+
```
|
117 |
+
|
118 |
+
Below are the results of inferring one frame on an RTX 4090 GPU using the native PyTorch framework with `torch.compile`:
|
119 |
+
|
120 |
+
| Model | Parameters(M) | Model Size(MB) | Inference(ms) |
|
121 |
+
|-----------------------------------|:-------------:|:--------------:|:-------------:|
|
122 |
+
| Appearance Feature Extractor | 0.84 | 3.3 | 0.82 |
|
123 |
+
| Motion Extractor | 28.12 | 108 | 0.84 |
|
124 |
+
| Spade Generator | 55.37 | 212 | 7.59 |
|
125 |
+
| Warping Module | 45.53 | 174 | 5.21 |
|
126 |
+
| Stitching and Retargeting Modules| 0.23 | 2.3 | 0.31 |
|
127 |
+
|
128 |
+
*Note: the listed values of Stitching and Retargeting Modules represent the combined parameter counts and the total sequential inference time of three MLP networks.*
|
129 |
+
|
130 |
+
|
131 |
+
## Acknowledgements
|
132 |
+
We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) repositories, for their open research and contributions.
|
133 |
+
|
134 |
+
## Citation 💖
|
135 |
+
If you find LivePortrait useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX:
|
136 |
+
```bibtex
|
137 |
+
@article{guo2024live,
|
138 |
+
title = {LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control},
|
139 |
+
author = {Jianzhu Guo and Dingyun Zhang and Xiaoqiang Liu and Zhizhou Zhong and Yuan Zhang and Pengfei Wan and Di Zhang},
|
140 |
+
year = {2024},
|
141 |
+
journal = {arXiv preprint:2407.03168},
|
142 |
+
}
|
143 |
+
```
|
requirements.txt
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu118
|
2 |
+
torch==2.3.0
|
3 |
+
torchvision==0.18.0
|
4 |
+
torchaudio==2.3.0
|
5 |
+
|
6 |
+
numpy==1.26.4
|
7 |
+
pyyaml==6.0.1
|
8 |
+
opencv-python==4.10.0.84
|
9 |
+
scipy==1.13.1
|
10 |
+
imageio==2.34.2
|
11 |
+
lmdb==1.4.1
|
12 |
+
tqdm==4.66.4
|
13 |
+
rich==13.7.1
|
14 |
+
ffmpeg==1.4
|
15 |
+
onnxruntime-gpu==1.18.0
|
16 |
+
onnx==1.16.1
|
17 |
+
scikit-image==0.24.0
|
18 |
+
albumentations==1.4.10
|
19 |
+
matplotlib==3.9.0
|
20 |
+
imageio-ffmpeg==0.5.1
|
21 |
+
tyro==0.8.5
|
22 |
+
gradio==4.37.1
|
speed.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
Benchmark the inference speed of each module in LivePortrait.
|
5 |
+
|
6 |
+
TODO: heavy GPT style, need to refactor
|
7 |
+
"""
|
8 |
+
|
9 |
+
import yaml
|
10 |
+
import torch
|
11 |
+
import time
|
12 |
+
import numpy as np
|
13 |
+
from src.utils.helper import load_model, concat_feat
|
14 |
+
from src.config.inference_config import InferenceConfig
|
15 |
+
|
16 |
+
|
17 |
+
def initialize_inputs(batch_size=1):
|
18 |
+
"""
|
19 |
+
Generate random input tensors and move them to GPU
|
20 |
+
"""
|
21 |
+
feature_3d = torch.randn(batch_size, 32, 16, 64, 64).cuda().half()
|
22 |
+
kp_source = torch.randn(batch_size, 21, 3).cuda().half()
|
23 |
+
kp_driving = torch.randn(batch_size, 21, 3).cuda().half()
|
24 |
+
source_image = torch.randn(batch_size, 3, 256, 256).cuda().half()
|
25 |
+
generator_input = torch.randn(batch_size, 256, 64, 64).cuda().half()
|
26 |
+
eye_close_ratio = torch.randn(batch_size, 3).cuda().half()
|
27 |
+
lip_close_ratio = torch.randn(batch_size, 2).cuda().half()
|
28 |
+
feat_stitching = concat_feat(kp_source, kp_driving).half()
|
29 |
+
feat_eye = concat_feat(kp_source, eye_close_ratio).half()
|
30 |
+
feat_lip = concat_feat(kp_source, lip_close_ratio).half()
|
31 |
+
|
32 |
+
inputs = {
|
33 |
+
'feature_3d': feature_3d,
|
34 |
+
'kp_source': kp_source,
|
35 |
+
'kp_driving': kp_driving,
|
36 |
+
'source_image': source_image,
|
37 |
+
'generator_input': generator_input,
|
38 |
+
'feat_stitching': feat_stitching,
|
39 |
+
'feat_eye': feat_eye,
|
40 |
+
'feat_lip': feat_lip
|
41 |
+
}
|
42 |
+
|
43 |
+
return inputs
|
44 |
+
|
45 |
+
|
46 |
+
def load_and_compile_models(cfg, model_config):
|
47 |
+
"""
|
48 |
+
Load and compile models for inference
|
49 |
+
"""
|
50 |
+
appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
|
51 |
+
motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
|
52 |
+
warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
|
53 |
+
spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
|
54 |
+
stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
|
55 |
+
|
56 |
+
models_with_params = [
|
57 |
+
('Appearance Feature Extractor', appearance_feature_extractor),
|
58 |
+
('Motion Extractor', motion_extractor),
|
59 |
+
('Warping Network', warping_module),
|
60 |
+
('SPADE Decoder', spade_generator)
|
61 |
+
]
|
62 |
+
|
63 |
+
compiled_models = {}
|
64 |
+
for name, model in models_with_params:
|
65 |
+
model = model.half()
|
66 |
+
model = torch.compile(model, mode='max-autotune') # Optimize for inference
|
67 |
+
model.eval() # Switch to evaluation mode
|
68 |
+
compiled_models[name] = model
|
69 |
+
|
70 |
+
retargeting_models = ['stitching', 'eye', 'lip']
|
71 |
+
for retarget in retargeting_models:
|
72 |
+
module = stitching_retargeting_module[retarget].half()
|
73 |
+
module = torch.compile(module, mode='max-autotune') # Optimize for inference
|
74 |
+
module.eval() # Switch to evaluation mode
|
75 |
+
stitching_retargeting_module[retarget] = module
|
76 |
+
|
77 |
+
return compiled_models, stitching_retargeting_module
|
78 |
+
|
79 |
+
|
80 |
+
def warm_up_models(compiled_models, stitching_retargeting_module, inputs):
|
81 |
+
"""
|
82 |
+
Warm up models to prepare them for benchmarking
|
83 |
+
"""
|
84 |
+
print("Warm up start!")
|
85 |
+
with torch.no_grad():
|
86 |
+
for _ in range(10):
|
87 |
+
compiled_models['Appearance Feature Extractor'](inputs['source_image'])
|
88 |
+
compiled_models['Motion Extractor'](inputs['source_image'])
|
89 |
+
compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
|
90 |
+
compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required
|
91 |
+
stitching_retargeting_module['stitching'](inputs['feat_stitching'])
|
92 |
+
stitching_retargeting_module['eye'](inputs['feat_eye'])
|
93 |
+
stitching_retargeting_module['lip'](inputs['feat_lip'])
|
94 |
+
print("Warm up end!")
|
95 |
+
|
96 |
+
|
97 |
+
def measure_inference_times(compiled_models, stitching_retargeting_module, inputs):
|
98 |
+
"""
|
99 |
+
Measure inference times for each model
|
100 |
+
"""
|
101 |
+
times = {name: [] for name in compiled_models.keys()}
|
102 |
+
times['Retargeting Models'] = []
|
103 |
+
|
104 |
+
overall_times = []
|
105 |
+
|
106 |
+
with torch.no_grad():
|
107 |
+
for _ in range(100):
|
108 |
+
torch.cuda.synchronize()
|
109 |
+
overall_start = time.time()
|
110 |
+
|
111 |
+
start = time.time()
|
112 |
+
compiled_models['Appearance Feature Extractor'](inputs['source_image'])
|
113 |
+
torch.cuda.synchronize()
|
114 |
+
times['Appearance Feature Extractor'].append(time.time() - start)
|
115 |
+
|
116 |
+
start = time.time()
|
117 |
+
compiled_models['Motion Extractor'](inputs['source_image'])
|
118 |
+
torch.cuda.synchronize()
|
119 |
+
times['Motion Extractor'].append(time.time() - start)
|
120 |
+
|
121 |
+
start = time.time()
|
122 |
+
compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
|
123 |
+
torch.cuda.synchronize()
|
124 |
+
times['Warping Network'].append(time.time() - start)
|
125 |
+
|
126 |
+
start = time.time()
|
127 |
+
compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required
|
128 |
+
torch.cuda.synchronize()
|
129 |
+
times['SPADE Decoder'].append(time.time() - start)
|
130 |
+
|
131 |
+
start = time.time()
|
132 |
+
stitching_retargeting_module['stitching'](inputs['feat_stitching'])
|
133 |
+
stitching_retargeting_module['eye'](inputs['feat_eye'])
|
134 |
+
stitching_retargeting_module['lip'](inputs['feat_lip'])
|
135 |
+
torch.cuda.synchronize()
|
136 |
+
times['Retargeting Models'].append(time.time() - start)
|
137 |
+
|
138 |
+
overall_times.append(time.time() - overall_start)
|
139 |
+
|
140 |
+
return times, overall_times
|
141 |
+
|
142 |
+
|
143 |
+
def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times):
|
144 |
+
"""
|
145 |
+
Print benchmark results with average and standard deviation of inference times
|
146 |
+
"""
|
147 |
+
average_times = {name: np.mean(times[name]) * 1000 for name in times.keys()}
|
148 |
+
std_times = {name: np.std(times[name]) * 1000 for name in times.keys()}
|
149 |
+
|
150 |
+
for name, model in compiled_models.items():
|
151 |
+
num_params = sum(p.numel() for p in model.parameters())
|
152 |
+
num_params_in_millions = num_params / 1e6
|
153 |
+
print(f"Number of parameters for {name}: {num_params_in_millions:.2f} M")
|
154 |
+
|
155 |
+
for index, retarget in enumerate(retargeting_models):
|
156 |
+
num_params = sum(p.numel() for p in stitching_retargeting_module[retarget].parameters())
|
157 |
+
num_params_in_millions = num_params / 1e6
|
158 |
+
print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params_in_millions:.2f} M")
|
159 |
+
|
160 |
+
for name, avg_time in average_times.items():
|
161 |
+
std_time = std_times[name]
|
162 |
+
print(f"Average inference time for {name} over 100 runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")
|
163 |
+
|
164 |
+
|
165 |
+
def main():
|
166 |
+
"""
|
167 |
+
Main function to benchmark speed and model parameters
|
168 |
+
"""
|
169 |
+
# Sample input tensors
|
170 |
+
inputs = initialize_inputs()
|
171 |
+
|
172 |
+
# Load configuration
|
173 |
+
cfg = InferenceConfig(device_id=0)
|
174 |
+
model_config_path = cfg.models_config
|
175 |
+
with open(model_config_path, 'r') as file:
|
176 |
+
model_config = yaml.safe_load(file)
|
177 |
+
|
178 |
+
# Load and compile models
|
179 |
+
compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)
|
180 |
+
|
181 |
+
# Warm up models
|
182 |
+
warm_up_models(compiled_models, stitching_retargeting_module, inputs)
|
183 |
+
|
184 |
+
# Measure inference times
|
185 |
+
times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs)
|
186 |
+
|
187 |
+
# Print benchmark results
|
188 |
+
print_benchmark_results(compiled_models, stitching_retargeting_module, ['stitching', 'eye', 'lip'], times, overall_times)
|
189 |
+
|
190 |
+
|
191 |
+
if __name__ == "__main__":
|
192 |
+
main()
|
src/config/__init__.py
ADDED
File without changes
|
src/config/argument_config.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
config for user
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os.path as osp
|
8 |
+
from dataclasses import dataclass
|
9 |
+
import tyro
|
10 |
+
from typing_extensions import Annotated
|
11 |
+
from .base_config import PrintableConfig, make_abs_path
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass(repr=False) # use repr from PrintableConfig
|
15 |
+
class ArgumentConfig(PrintableConfig):
|
16 |
+
########## input arguments ##########
|
17 |
+
source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait
|
18 |
+
driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
|
19 |
+
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
|
20 |
+
#####################################
|
21 |
+
|
22 |
+
########## inference arguments ##########
|
23 |
+
device_id: int = 0
|
24 |
+
flag_lip_zero : bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
|
25 |
+
flag_eye_retargeting: bool = False
|
26 |
+
flag_lip_retargeting: bool = False
|
27 |
+
flag_stitching: bool = True # we recommend setting it to True!
|
28 |
+
flag_relative: bool = True # whether to use relative motion
|
29 |
+
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
|
30 |
+
flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space
|
31 |
+
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
|
32 |
+
#########################################
|
33 |
+
|
34 |
+
########## crop arguments ##########
|
35 |
+
dsize: int = 512
|
36 |
+
scale: float = 2.3
|
37 |
+
vx_ratio: float = 0 # vx ratio
|
38 |
+
vy_ratio: float = -0.125 # vy ratio +up, -down
|
39 |
+
####################################
|
40 |
+
|
41 |
+
########## gradio arguments ##########
|
42 |
+
server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890
|
43 |
+
share: bool = True
|
44 |
+
server_name: str = "0.0.0.0"
|
src/config/base_config.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
pretty printing class
|
5 |
+
"""
|
6 |
+
|
7 |
+
from __future__ import annotations
|
8 |
+
import os.path as osp
|
9 |
+
from typing import Tuple
|
10 |
+
|
11 |
+
|
12 |
+
def make_abs_path(fn):
|
13 |
+
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
|
14 |
+
|
15 |
+
|
16 |
+
class PrintableConfig: # pylint: disable=too-few-public-methods
|
17 |
+
"""Printable Config defining str function"""
|
18 |
+
|
19 |
+
def __repr__(self):
|
20 |
+
lines = [self.__class__.__name__ + ":"]
|
21 |
+
for key, val in vars(self).items():
|
22 |
+
if isinstance(val, Tuple):
|
23 |
+
flattened_val = "["
|
24 |
+
for item in val:
|
25 |
+
flattened_val += str(item) + "\n"
|
26 |
+
flattened_val = flattened_val.rstrip("\n")
|
27 |
+
val = flattened_val + "]"
|
28 |
+
lines += f"{key}: {str(val)}".split("\n")
|
29 |
+
return "\n ".join(lines)
|
src/config/crop_config.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
parameters used for crop faces
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os.path as osp
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from typing import Union, List
|
10 |
+
from .base_config import PrintableConfig
|
11 |
+
|
12 |
+
|
13 |
+
@dataclass(repr=False) # use repr from PrintableConfig
|
14 |
+
class CropConfig(PrintableConfig):
|
15 |
+
dsize: int = 512 # crop size
|
16 |
+
scale: float = 2.3 # scale factor
|
17 |
+
vx_ratio: float = 0 # vx ratio
|
18 |
+
vy_ratio: float = -0.125 # vy ratio +up, -down
|
src/config/inference_config.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
config dataclass used for inference
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os.path as osp
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from typing import Literal, Tuple
|
10 |
+
from .base_config import PrintableConfig, make_abs_path
|
11 |
+
|
12 |
+
|
13 |
+
@dataclass(repr=False) # use repr from PrintableConfig
|
14 |
+
class InferenceConfig(PrintableConfig):
|
15 |
+
models_config: str = make_abs_path('./models.yaml') # portrait animation config
|
16 |
+
checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint
|
17 |
+
checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint
|
18 |
+
checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint
|
19 |
+
checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint
|
20 |
+
|
21 |
+
checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint
|
22 |
+
flag_use_half_precision: bool = True # whether to use half precision
|
23 |
+
|
24 |
+
flag_lip_zero: bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
|
25 |
+
lip_zero_threshold: float = 0.03
|
26 |
+
|
27 |
+
flag_eye_retargeting: bool = False
|
28 |
+
flag_lip_retargeting: bool = False
|
29 |
+
flag_stitching: bool = True # we recommend setting it to True!
|
30 |
+
|
31 |
+
flag_relative: bool = True # whether to use relative motion
|
32 |
+
anchor_frame: int = 0 # set this value if find_best_frame is True
|
33 |
+
|
34 |
+
input_shape: Tuple[int, int] = (256, 256) # input shape
|
35 |
+
output_format: Literal['mp4', 'gif'] = 'mp4' # output video format
|
36 |
+
output_fps: int = 30 # fps for output video
|
37 |
+
crf: int = 15 # crf for output video
|
38 |
+
|
39 |
+
flag_write_result: bool = True # whether to write output video
|
40 |
+
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
|
41 |
+
mask_crop = None
|
42 |
+
flag_write_gif: bool = False
|
43 |
+
size_gif: int = 256
|
44 |
+
ref_max_shape: int = 1280
|
45 |
+
ref_shape_n: int = 2
|
46 |
+
|
47 |
+
device_id: int = 0
|
48 |
+
flag_do_crop: bool = False # whether to crop the source portrait to the face-cropping space
|
49 |
+
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
|
src/config/models.yaml
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_params:
|
2 |
+
appearance_feature_extractor_params: # the F in the paper
|
3 |
+
image_channel: 3
|
4 |
+
block_expansion: 64
|
5 |
+
num_down_blocks: 2
|
6 |
+
max_features: 512
|
7 |
+
reshape_channel: 32
|
8 |
+
reshape_depth: 16
|
9 |
+
num_resblocks: 6
|
10 |
+
motion_extractor_params: # the M in the paper
|
11 |
+
num_kp: 21
|
12 |
+
backbone: convnextv2_tiny
|
13 |
+
warping_module_params: # the W in the paper
|
14 |
+
num_kp: 21
|
15 |
+
block_expansion: 64
|
16 |
+
max_features: 512
|
17 |
+
num_down_blocks: 2
|
18 |
+
reshape_channel: 32
|
19 |
+
estimate_occlusion_map: True
|
20 |
+
dense_motion_params:
|
21 |
+
block_expansion: 32
|
22 |
+
max_features: 1024
|
23 |
+
num_blocks: 5
|
24 |
+
reshape_depth: 16
|
25 |
+
compress: 4
|
26 |
+
spade_generator_params: # the G in the paper
|
27 |
+
upscale: 2 # represents upsample factor 256x256 -> 512x512
|
28 |
+
block_expansion: 64
|
29 |
+
max_features: 512
|
30 |
+
num_down_blocks: 2
|
31 |
+
stitching_retargeting_module_params: # the S in the paper
|
32 |
+
stitching:
|
33 |
+
input_size: 126 # (21*3)*2
|
34 |
+
hidden_sizes: [128, 128, 64]
|
35 |
+
output_size: 65 # (21*3)+2(tx,ty)
|
36 |
+
lip:
|
37 |
+
input_size: 65 # (21*3)+2
|
38 |
+
hidden_sizes: [128, 128, 64]
|
39 |
+
output_size: 63 # (21*3)
|
40 |
+
eye:
|
41 |
+
input_size: 66 # (21*3)+3
|
42 |
+
hidden_sizes: [256, 256, 128, 128, 64]
|
43 |
+
output_size: 63 # (21*3)
|
src/gradio_pipeline.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
Pipeline for gradio
|
5 |
+
"""
|
6 |
+
import gradio as gr
|
7 |
+
from .config.argument_config import ArgumentConfig
|
8 |
+
from .live_portrait_pipeline import LivePortraitPipeline
|
9 |
+
from .utils.io import load_img_online
|
10 |
+
from .utils.rprint import rlog as log
|
11 |
+
from .utils.crop import prepare_paste_back, paste_back
|
12 |
+
from .utils.camera import get_rotation_matrix
|
13 |
+
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
|
14 |
+
|
15 |
+
def update_args(args, user_args):
|
16 |
+
"""update the args according to user inputs
|
17 |
+
"""
|
18 |
+
for k, v in user_args.items():
|
19 |
+
if hasattr(args, k):
|
20 |
+
setattr(args, k, v)
|
21 |
+
return args
|
22 |
+
|
23 |
+
class GradioPipeline(LivePortraitPipeline):
|
24 |
+
|
25 |
+
def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
|
26 |
+
super().__init__(inference_cfg, crop_cfg)
|
27 |
+
# self.live_portrait_wrapper = self.live_portrait_wrapper
|
28 |
+
self.args = args
|
29 |
+
# for single image retargeting
|
30 |
+
self.start_prepare = False
|
31 |
+
self.f_s_user = None
|
32 |
+
self.x_c_s_info_user = None
|
33 |
+
self.x_s_user = None
|
34 |
+
self.source_lmk_user = None
|
35 |
+
self.mask_ori = None
|
36 |
+
self.img_rgb = None
|
37 |
+
self.crop_M_c2o = None
|
38 |
+
|
39 |
+
|
40 |
+
def execute_video(
|
41 |
+
self,
|
42 |
+
input_image_path,
|
43 |
+
input_video_path,
|
44 |
+
flag_relative_input,
|
45 |
+
flag_do_crop_input,
|
46 |
+
flag_remap_input,
|
47 |
+
):
|
48 |
+
""" for video driven potrait animation
|
49 |
+
"""
|
50 |
+
if input_image_path is not None and input_video_path is not None:
|
51 |
+
args_user = {
|
52 |
+
'source_image': input_image_path,
|
53 |
+
'driving_info': input_video_path,
|
54 |
+
'flag_relative': flag_relative_input,
|
55 |
+
'flag_do_crop': flag_do_crop_input,
|
56 |
+
'flag_pasteback': flag_remap_input,
|
57 |
+
}
|
58 |
+
# update config from user input
|
59 |
+
self.args = update_args(self.args, args_user)
|
60 |
+
self.live_portrait_wrapper.update_config(self.args.__dict__)
|
61 |
+
self.cropper.update_config(self.args.__dict__)
|
62 |
+
# video driven animation
|
63 |
+
video_path, video_path_concat = self.execute(self.args)
|
64 |
+
gr.Info("Run successfully!", duration=2)
|
65 |
+
return video_path, video_path_concat,
|
66 |
+
else:
|
67 |
+
raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
|
68 |
+
|
69 |
+
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
|
70 |
+
""" for single image retargeting
|
71 |
+
"""
|
72 |
+
if input_eye_ratio is None or input_eye_ratio is None:
|
73 |
+
raise gr.Error("Invalid ratio input 💥!", duration=5)
|
74 |
+
elif self.f_s_user is None:
|
75 |
+
if self.start_prepare:
|
76 |
+
raise gr.Error(
|
77 |
+
"The source portrait is under processing 💥! Please wait for a second.",
|
78 |
+
duration=5
|
79 |
+
)
|
80 |
+
else:
|
81 |
+
raise gr.Error(
|
82 |
+
"The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
|
83 |
+
duration=5
|
84 |
+
)
|
85 |
+
else:
|
86 |
+
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
|
87 |
+
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
|
88 |
+
eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
|
89 |
+
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
|
90 |
+
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
|
91 |
+
lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
|
92 |
+
num_kp = self.x_s_user.shape[1]
|
93 |
+
# default: use x_s
|
94 |
+
x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
|
95 |
+
# D(W(f_s; x_s, x′_d))
|
96 |
+
out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
|
97 |
+
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
98 |
+
out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
|
99 |
+
gr.Info("Run successfully!", duration=2)
|
100 |
+
return out, out_to_ori_blend
|
101 |
+
|
102 |
+
|
103 |
+
def prepare_retargeting(self, input_image_path, flag_do_crop = True):
|
104 |
+
""" for single image retargeting
|
105 |
+
"""
|
106 |
+
if input_image_path is not None:
|
107 |
+
gr.Info("Upload successfully!", duration=2)
|
108 |
+
self.start_prepare = True
|
109 |
+
inference_cfg = self.live_portrait_wrapper.cfg
|
110 |
+
######## process source portrait ########
|
111 |
+
img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
|
112 |
+
log(f"Load source image from {input_image_path}.")
|
113 |
+
crop_info = self.cropper.crop_single_image(img_rgb)
|
114 |
+
if flag_do_crop:
|
115 |
+
I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
|
116 |
+
else:
|
117 |
+
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
|
118 |
+
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
|
119 |
+
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
120 |
+
############################################
|
121 |
+
|
122 |
+
# record global info for next time use
|
123 |
+
self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
124 |
+
self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
|
125 |
+
self.x_s_info_user = x_s_info
|
126 |
+
self.source_lmk_user = crop_info['lmk_crop']
|
127 |
+
self.img_rgb = img_rgb
|
128 |
+
self.crop_M_c2o = crop_info['M_c2o']
|
129 |
+
self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
|
130 |
+
# update slider
|
131 |
+
eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
|
132 |
+
eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
|
133 |
+
lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
|
134 |
+
lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
|
135 |
+
# for vis
|
136 |
+
self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
|
137 |
+
return eye_close_ratio, lip_close_ratio, self.I_s_vis
|
138 |
+
else:
|
139 |
+
# when press the clear button, go here
|
140 |
+
return 0.8, 0.8, self.I_s_vis
|
src/live_portrait_pipeline.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
Pipeline of LivePortrait
|
5 |
+
"""
|
6 |
+
|
7 |
+
# TODO:
|
8 |
+
# 1. 当前假定所有的模板都是已经裁好的,需要修改下
|
9 |
+
# 2. pick样例图 source + driving
|
10 |
+
|
11 |
+
import cv2
|
12 |
+
import numpy as np
|
13 |
+
import pickle
|
14 |
+
import os.path as osp
|
15 |
+
from rich.progress import track
|
16 |
+
|
17 |
+
from .config.argument_config import ArgumentConfig
|
18 |
+
from .config.inference_config import InferenceConfig
|
19 |
+
from .config.crop_config import CropConfig
|
20 |
+
from .utils.cropper import Cropper
|
21 |
+
from .utils.camera import get_rotation_matrix
|
22 |
+
from .utils.video import images2video, concat_frames
|
23 |
+
from .utils.crop import _transform_img, prepare_paste_back, paste_back
|
24 |
+
from .utils.retargeting_utils import calc_lip_close_ratio
|
25 |
+
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
|
26 |
+
from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template
|
27 |
+
from .utils.rprint import rlog as log
|
28 |
+
from .live_portrait_wrapper import LivePortraitWrapper
|
29 |
+
|
30 |
+
|
31 |
+
def make_abs_path(fn):
|
32 |
+
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
|
33 |
+
|
34 |
+
|
35 |
+
class LivePortraitPipeline(object):
|
36 |
+
|
37 |
+
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
|
38 |
+
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
|
39 |
+
self.cropper = Cropper(crop_cfg=crop_cfg)
|
40 |
+
|
41 |
+
def execute(self, args: ArgumentConfig):
|
42 |
+
inference_cfg = self.live_portrait_wrapper.cfg # for convenience
|
43 |
+
######## process source portrait ########
|
44 |
+
img_rgb = load_image_rgb(args.source_image)
|
45 |
+
img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
|
46 |
+
log(f"Load source image from {args.source_image}")
|
47 |
+
crop_info = self.cropper.crop_single_image(img_rgb)
|
48 |
+
source_lmk = crop_info['lmk_crop']
|
49 |
+
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
|
50 |
+
if inference_cfg.flag_do_crop:
|
51 |
+
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
|
52 |
+
else:
|
53 |
+
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
|
54 |
+
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
|
55 |
+
x_c_s = x_s_info['kp']
|
56 |
+
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
57 |
+
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
58 |
+
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
|
59 |
+
|
60 |
+
if inference_cfg.flag_lip_zero:
|
61 |
+
# let lip-open scalar to be 0 at first
|
62 |
+
c_d_lip_before_animation = [0.]
|
63 |
+
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
|
64 |
+
if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
|
65 |
+
inference_cfg.flag_lip_zero = False
|
66 |
+
else:
|
67 |
+
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
|
68 |
+
############################################
|
69 |
+
|
70 |
+
######## process driving info ########
|
71 |
+
if is_video(args.driving_info):
|
72 |
+
log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
|
73 |
+
# TODO: 这里track一下驱动视频 -> 构建模板
|
74 |
+
driving_rgb_lst = load_driving_info(args.driving_info)
|
75 |
+
driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
|
76 |
+
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256)
|
77 |
+
n_frames = I_d_lst.shape[0]
|
78 |
+
if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
|
79 |
+
driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
|
80 |
+
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
|
81 |
+
elif is_template(args.driving_info):
|
82 |
+
log(f"Load from video templates {args.driving_info}")
|
83 |
+
with open(args.driving_info, 'rb') as f:
|
84 |
+
template_lst, driving_lmk_lst = pickle.load(f)
|
85 |
+
n_frames = template_lst[0]['n_frames']
|
86 |
+
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
|
87 |
+
else:
|
88 |
+
raise Exception("Unsupported driving types!")
|
89 |
+
#########################################
|
90 |
+
|
91 |
+
######## prepare for pasteback ########
|
92 |
+
if inference_cfg.flag_pasteback:
|
93 |
+
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
|
94 |
+
I_p_paste_lst = []
|
95 |
+
#########################################
|
96 |
+
|
97 |
+
I_p_lst = []
|
98 |
+
R_d_0, x_d_0_info = None, None
|
99 |
+
for i in track(range(n_frames), description='Animating...', total=n_frames):
|
100 |
+
if is_video(args.driving_info):
|
101 |
+
# extract kp info by M
|
102 |
+
I_d_i = I_d_lst[i]
|
103 |
+
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
|
104 |
+
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
|
105 |
+
else:
|
106 |
+
# from template
|
107 |
+
x_d_i_info = template_lst[i]
|
108 |
+
x_d_i_info = dct2cuda(x_d_i_info, inference_cfg.device_id)
|
109 |
+
R_d_i = x_d_i_info['R_d']
|
110 |
+
|
111 |
+
if i == 0:
|
112 |
+
R_d_0 = R_d_i
|
113 |
+
x_d_0_info = x_d_i_info
|
114 |
+
|
115 |
+
if inference_cfg.flag_relative:
|
116 |
+
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
|
117 |
+
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
|
118 |
+
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
|
119 |
+
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
|
120 |
+
else:
|
121 |
+
R_new = R_d_i
|
122 |
+
delta_new = x_d_i_info['exp']
|
123 |
+
scale_new = x_s_info['scale']
|
124 |
+
t_new = x_d_i_info['t']
|
125 |
+
|
126 |
+
t_new[..., 2].fill_(0) # zero tz
|
127 |
+
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
|
128 |
+
|
129 |
+
# Algorithm 1:
|
130 |
+
if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
|
131 |
+
# without stitching or retargeting
|
132 |
+
if inference_cfg.flag_lip_zero:
|
133 |
+
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
134 |
+
else:
|
135 |
+
pass
|
136 |
+
elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
|
137 |
+
# with stitching and without retargeting
|
138 |
+
if inference_cfg.flag_lip_zero:
|
139 |
+
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
140 |
+
else:
|
141 |
+
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
|
142 |
+
else:
|
143 |
+
eyes_delta, lip_delta = None, None
|
144 |
+
if inference_cfg.flag_eye_retargeting:
|
145 |
+
c_d_eyes_i = input_eye_ratio_lst[i]
|
146 |
+
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
|
147 |
+
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
|
148 |
+
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
|
149 |
+
if inference_cfg.flag_lip_retargeting:
|
150 |
+
c_d_lip_i = input_lip_ratio_lst[i]
|
151 |
+
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
|
152 |
+
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
|
153 |
+
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
|
154 |
+
|
155 |
+
if inference_cfg.flag_relative: # use x_s
|
156 |
+
x_d_i_new = x_s + \
|
157 |
+
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
|
158 |
+
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
|
159 |
+
else: # use x_d,i
|
160 |
+
x_d_i_new = x_d_i_new + \
|
161 |
+
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
|
162 |
+
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
|
163 |
+
|
164 |
+
if inference_cfg.flag_stitching:
|
165 |
+
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
|
166 |
+
|
167 |
+
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
|
168 |
+
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
169 |
+
I_p_lst.append(I_p_i)
|
170 |
+
|
171 |
+
if inference_cfg.flag_pasteback:
|
172 |
+
I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
|
173 |
+
I_p_paste_lst.append(I_p_i_to_ori_blend)
|
174 |
+
|
175 |
+
mkdir(args.output_dir)
|
176 |
+
wfp_concat = None
|
177 |
+
if is_video(args.driving_info):
|
178 |
+
frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
|
179 |
+
# save (driving frames, source image, drived frames) result
|
180 |
+
wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
|
181 |
+
images2video(frames_concatenated, wfp=wfp_concat)
|
182 |
+
|
183 |
+
# save drived result
|
184 |
+
wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
|
185 |
+
if inference_cfg.flag_pasteback:
|
186 |
+
images2video(I_p_paste_lst, wfp=wfp)
|
187 |
+
else:
|
188 |
+
images2video(I_p_lst, wfp=wfp)
|
189 |
+
|
190 |
+
return wfp, wfp_concat
|
src/live_portrait_wrapper.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
Wrapper for LivePortrait core functions
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os.path as osp
|
8 |
+
import numpy as np
|
9 |
+
import cv2
|
10 |
+
import torch
|
11 |
+
import yaml
|
12 |
+
|
13 |
+
from .utils.timer import Timer
|
14 |
+
from .utils.helper import load_model, concat_feat
|
15 |
+
from .utils.camera import headpose_pred_to_degree, get_rotation_matrix
|
16 |
+
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
|
17 |
+
from .config.inference_config import InferenceConfig
|
18 |
+
from .utils.rprint import rlog as log
|
19 |
+
|
20 |
+
|
21 |
+
class LivePortraitWrapper(object):
|
22 |
+
|
23 |
+
def __init__(self, cfg: InferenceConfig):
|
24 |
+
|
25 |
+
model_config = yaml.load(open(cfg.models_config, 'r'), Loader=yaml.SafeLoader)
|
26 |
+
|
27 |
+
# init F
|
28 |
+
self.appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
|
29 |
+
log(f'Load appearance_feature_extractor done.')
|
30 |
+
# init M
|
31 |
+
self.motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
|
32 |
+
log(f'Load motion_extractor done.')
|
33 |
+
# init W
|
34 |
+
self.warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
|
35 |
+
log(f'Load warping_module done.')
|
36 |
+
# init G
|
37 |
+
self.spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
|
38 |
+
log(f'Load spade_generator done.')
|
39 |
+
# init S and R
|
40 |
+
if cfg.checkpoint_S is not None and osp.exists(cfg.checkpoint_S):
|
41 |
+
self.stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
|
42 |
+
log(f'Load stitching_retargeting_module done.')
|
43 |
+
else:
|
44 |
+
self.stitching_retargeting_module = None
|
45 |
+
|
46 |
+
self.cfg = cfg
|
47 |
+
self.device_id = cfg.device_id
|
48 |
+
self.timer = Timer()
|
49 |
+
|
50 |
+
def update_config(self, user_args):
|
51 |
+
for k, v in user_args.items():
|
52 |
+
if hasattr(self.cfg, k):
|
53 |
+
setattr(self.cfg, k, v)
|
54 |
+
|
55 |
+
def prepare_source(self, img: np.ndarray) -> torch.Tensor:
|
56 |
+
""" construct the input as standard
|
57 |
+
img: HxWx3, uint8, 256x256
|
58 |
+
"""
|
59 |
+
h, w = img.shape[:2]
|
60 |
+
if h != self.cfg.input_shape[0] or w != self.cfg.input_shape[1]:
|
61 |
+
x = cv2.resize(img, (self.cfg.input_shape[0], self.cfg.input_shape[1]))
|
62 |
+
else:
|
63 |
+
x = img.copy()
|
64 |
+
|
65 |
+
if x.ndim == 3:
|
66 |
+
x = x[np.newaxis].astype(np.float32) / 255. # HxWx3 -> 1xHxWx3, normalized to 0~1
|
67 |
+
elif x.ndim == 4:
|
68 |
+
x = x.astype(np.float32) / 255. # BxHxWx3, normalized to 0~1
|
69 |
+
else:
|
70 |
+
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
|
71 |
+
x = np.clip(x, 0, 1) # clip to 0~1
|
72 |
+
x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
|
73 |
+
x = x.cuda(self.device_id)
|
74 |
+
return x
|
75 |
+
|
76 |
+
def prepare_driving_videos(self, imgs) -> torch.Tensor:
|
77 |
+
""" construct the input as standard
|
78 |
+
imgs: NxBxHxWx3, uint8
|
79 |
+
"""
|
80 |
+
if isinstance(imgs, list):
|
81 |
+
_imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1
|
82 |
+
elif isinstance(imgs, np.ndarray):
|
83 |
+
_imgs = imgs
|
84 |
+
else:
|
85 |
+
raise ValueError(f'imgs type error: {type(imgs)}')
|
86 |
+
|
87 |
+
y = _imgs.astype(np.float32) / 255.
|
88 |
+
y = np.clip(y, 0, 1) # clip to 0~1
|
89 |
+
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
|
90 |
+
y = y.cuda(self.device_id)
|
91 |
+
|
92 |
+
return y
|
93 |
+
|
94 |
+
def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
|
95 |
+
""" get the appearance feature of the image by F
|
96 |
+
x: Bx3xHxW, normalized to 0~1
|
97 |
+
"""
|
98 |
+
with torch.no_grad():
|
99 |
+
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
|
100 |
+
feature_3d = self.appearance_feature_extractor(x)
|
101 |
+
|
102 |
+
return feature_3d.float()
|
103 |
+
|
104 |
+
def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
|
105 |
+
""" get the implicit keypoint information
|
106 |
+
x: Bx3xHxW, normalized to 0~1
|
107 |
+
flag_refine_info: whether to trandform the pose to degrees and the dimention of the reshape
|
108 |
+
return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
|
109 |
+
"""
|
110 |
+
with torch.no_grad():
|
111 |
+
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
|
112 |
+
kp_info = self.motion_extractor(x)
|
113 |
+
|
114 |
+
if self.cfg.flag_use_half_precision:
|
115 |
+
# float the dict
|
116 |
+
for k, v in kp_info.items():
|
117 |
+
if isinstance(v, torch.Tensor):
|
118 |
+
kp_info[k] = v.float()
|
119 |
+
|
120 |
+
flag_refine_info: bool = kwargs.get('flag_refine_info', True)
|
121 |
+
if flag_refine_info:
|
122 |
+
bs = kp_info['kp'].shape[0]
|
123 |
+
kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None] # Bx1
|
124 |
+
kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None] # Bx1
|
125 |
+
kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None] # Bx1
|
126 |
+
kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3) # BxNx3
|
127 |
+
kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3) # BxNx3
|
128 |
+
|
129 |
+
return kp_info
|
130 |
+
|
131 |
+
def get_pose_dct(self, kp_info: dict) -> dict:
|
132 |
+
pose_dct = dict(
|
133 |
+
pitch=headpose_pred_to_degree(kp_info['pitch']).item(),
|
134 |
+
yaw=headpose_pred_to_degree(kp_info['yaw']).item(),
|
135 |
+
roll=headpose_pred_to_degree(kp_info['roll']).item(),
|
136 |
+
)
|
137 |
+
return pose_dct
|
138 |
+
|
139 |
+
def get_fs_and_kp_info(self, source_prepared, driving_first_frame):
|
140 |
+
|
141 |
+
# get the canonical keypoints of source image by M
|
142 |
+
source_kp_info = self.get_kp_info(source_prepared, flag_refine_info=True)
|
143 |
+
source_rotation = get_rotation_matrix(source_kp_info['pitch'], source_kp_info['yaw'], source_kp_info['roll'])
|
144 |
+
|
145 |
+
# get the canonical keypoints of first driving frame by M
|
146 |
+
driving_first_frame_kp_info = self.get_kp_info(driving_first_frame, flag_refine_info=True)
|
147 |
+
driving_first_frame_rotation = get_rotation_matrix(
|
148 |
+
driving_first_frame_kp_info['pitch'],
|
149 |
+
driving_first_frame_kp_info['yaw'],
|
150 |
+
driving_first_frame_kp_info['roll']
|
151 |
+
)
|
152 |
+
|
153 |
+
# get feature volume by F
|
154 |
+
source_feature_3d = self.extract_feature_3d(source_prepared)
|
155 |
+
|
156 |
+
return source_kp_info, source_rotation, source_feature_3d, driving_first_frame_kp_info, driving_first_frame_rotation
|
157 |
+
|
158 |
+
def transform_keypoint(self, kp_info: dict):
|
159 |
+
"""
|
160 |
+
transform the implicit keypoints with the pose, shift, and expression deformation
|
161 |
+
kp: BxNx3
|
162 |
+
"""
|
163 |
+
kp = kp_info['kp'] # (bs, k, 3)
|
164 |
+
pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']
|
165 |
+
|
166 |
+
t, exp = kp_info['t'], kp_info['exp']
|
167 |
+
scale = kp_info['scale']
|
168 |
+
|
169 |
+
pitch = headpose_pred_to_degree(pitch)
|
170 |
+
yaw = headpose_pred_to_degree(yaw)
|
171 |
+
roll = headpose_pred_to_degree(roll)
|
172 |
+
|
173 |
+
bs = kp.shape[0]
|
174 |
+
if kp.ndim == 2:
|
175 |
+
num_kp = kp.shape[1] // 3 # Bx(num_kpx3)
|
176 |
+
else:
|
177 |
+
num_kp = kp.shape[1] # Bxnum_kpx3
|
178 |
+
|
179 |
+
rot_mat = get_rotation_matrix(pitch, yaw, roll) # (bs, 3, 3)
|
180 |
+
|
181 |
+
# Eqn.2: s * (R * x_c,s + exp) + t
|
182 |
+
kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
|
183 |
+
kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
|
184 |
+
kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty
|
185 |
+
|
186 |
+
return kp_transformed
|
187 |
+
|
188 |
+
def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor:
|
189 |
+
"""
|
190 |
+
kp_source: BxNx3
|
191 |
+
eye_close_ratio: Bx3
|
192 |
+
Return: Bx(3*num_kp+2)
|
193 |
+
"""
|
194 |
+
feat_eye = concat_feat(kp_source, eye_close_ratio)
|
195 |
+
|
196 |
+
with torch.no_grad():
|
197 |
+
delta = self.stitching_retargeting_module['eye'](feat_eye)
|
198 |
+
|
199 |
+
return delta
|
200 |
+
|
201 |
+
def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
|
202 |
+
"""
|
203 |
+
kp_source: BxNx3
|
204 |
+
lip_close_ratio: Bx2
|
205 |
+
"""
|
206 |
+
feat_lip = concat_feat(kp_source, lip_close_ratio)
|
207 |
+
|
208 |
+
with torch.no_grad():
|
209 |
+
delta = self.stitching_retargeting_module['lip'](feat_lip)
|
210 |
+
|
211 |
+
return delta
|
212 |
+
|
213 |
+
def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
|
214 |
+
"""
|
215 |
+
kp_source: BxNx3
|
216 |
+
kp_driving: BxNx3
|
217 |
+
Return: Bx(3*num_kp+2)
|
218 |
+
"""
|
219 |
+
feat_stiching = concat_feat(kp_source, kp_driving)
|
220 |
+
|
221 |
+
with torch.no_grad():
|
222 |
+
delta = self.stitching_retargeting_module['stitching'](feat_stiching)
|
223 |
+
|
224 |
+
return delta
|
225 |
+
|
226 |
+
def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
|
227 |
+
""" conduct the stitching
|
228 |
+
kp_source: Bxnum_kpx3
|
229 |
+
kp_driving: Bxnum_kpx3
|
230 |
+
"""
|
231 |
+
|
232 |
+
if self.stitching_retargeting_module is not None:
|
233 |
+
|
234 |
+
bs, num_kp = kp_source.shape[:2]
|
235 |
+
|
236 |
+
kp_driving_new = kp_driving.clone()
|
237 |
+
delta = self.stitch(kp_source, kp_driving_new)
|
238 |
+
|
239 |
+
delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3) # 1x20x3
|
240 |
+
delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2) # 1x1x2
|
241 |
+
|
242 |
+
kp_driving_new += delta_exp
|
243 |
+
kp_driving_new[..., :2] += delta_tx_ty
|
244 |
+
|
245 |
+
return kp_driving_new
|
246 |
+
|
247 |
+
return kp_driving
|
248 |
+
|
249 |
+
def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
|
250 |
+
""" get the image after the warping of the implicit keypoints
|
251 |
+
feature_3d: Bx32x16x64x64, feature volume
|
252 |
+
kp_source: BxNx3
|
253 |
+
kp_driving: BxNx3
|
254 |
+
"""
|
255 |
+
# The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
|
256 |
+
with torch.no_grad():
|
257 |
+
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
|
258 |
+
# get decoder input
|
259 |
+
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
|
260 |
+
# decode
|
261 |
+
ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
|
262 |
+
|
263 |
+
# float the dict
|
264 |
+
if self.cfg.flag_use_half_precision:
|
265 |
+
for k, v in ret_dct.items():
|
266 |
+
if isinstance(v, torch.Tensor):
|
267 |
+
ret_dct[k] = v.float()
|
268 |
+
|
269 |
+
return ret_dct
|
270 |
+
|
271 |
+
def parse_output(self, out: torch.Tensor) -> np.ndarray:
|
272 |
+
""" construct the output as standard
|
273 |
+
return: 1xHxWx3, uint8
|
274 |
+
"""
|
275 |
+
out = np.transpose(out.data.cpu().numpy(), [0, 2, 3, 1]) # 1x3xHxW -> 1xHxWx3
|
276 |
+
out = np.clip(out, 0, 1) # clip to 0~1
|
277 |
+
out = np.clip(out * 255, 0, 255).astype(np.uint8) # 0~1 -> 0~255
|
278 |
+
|
279 |
+
return out
|
280 |
+
|
281 |
+
def calc_retargeting_ratio(self, source_lmk, driving_lmk_lst):
|
282 |
+
input_eye_ratio_lst = []
|
283 |
+
input_lip_ratio_lst = []
|
284 |
+
for lmk in driving_lmk_lst:
|
285 |
+
# for eyes retargeting
|
286 |
+
input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
|
287 |
+
# for lip retargeting
|
288 |
+
input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
|
289 |
+
return input_eye_ratio_lst, input_lip_ratio_lst
|
290 |
+
|
291 |
+
def calc_combined_eye_ratio(self, input_eye_ratio, source_lmk):
|
292 |
+
eye_close_ratio = calc_eye_close_ratio(source_lmk[None])
|
293 |
+
eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(self.device_id)
|
294 |
+
input_eye_ratio_tensor = torch.Tensor([input_eye_ratio[0][0]]).reshape(1, 1).cuda(self.device_id)
|
295 |
+
# [c_s,eyes, c_d,eyes,i]
|
296 |
+
combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
|
297 |
+
return combined_eye_ratio_tensor
|
298 |
+
|
299 |
+
def calc_combined_lip_ratio(self, input_lip_ratio, source_lmk):
|
300 |
+
lip_close_ratio = calc_lip_close_ratio(source_lmk[None])
|
301 |
+
lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(self.device_id)
|
302 |
+
# [c_s,lip, c_d,lip,i]
|
303 |
+
input_lip_ratio_tensor = torch.Tensor([input_lip_ratio[0]]).cuda(self.device_id)
|
304 |
+
if input_lip_ratio_tensor.shape != [1, 1]:
|
305 |
+
input_lip_ratio_tensor = input_lip_ratio_tensor.reshape(1, 1)
|
306 |
+
combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
|
307 |
+
return combined_lip_ratio_tensor
|
src/modules/__init__.py
ADDED
File without changes
|
src/modules/appearance_feature_extractor.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
Appearance extractor(F) defined in paper, which maps the source image s to a 3D appearance feature volume.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from .util import SameBlock2d, DownBlock2d, ResBlock3d
|
10 |
+
|
11 |
+
|
12 |
+
class AppearanceFeatureExtractor(nn.Module):
|
13 |
+
|
14 |
+
def __init__(self, image_channel, block_expansion, num_down_blocks, max_features, reshape_channel, reshape_depth, num_resblocks):
|
15 |
+
super(AppearanceFeatureExtractor, self).__init__()
|
16 |
+
self.image_channel = image_channel
|
17 |
+
self.block_expansion = block_expansion
|
18 |
+
self.num_down_blocks = num_down_blocks
|
19 |
+
self.max_features = max_features
|
20 |
+
self.reshape_channel = reshape_channel
|
21 |
+
self.reshape_depth = reshape_depth
|
22 |
+
|
23 |
+
self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))
|
24 |
+
|
25 |
+
down_blocks = []
|
26 |
+
for i in range(num_down_blocks):
|
27 |
+
in_features = min(max_features, block_expansion * (2 ** i))
|
28 |
+
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
|
29 |
+
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
30 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
31 |
+
|
32 |
+
self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
|
33 |
+
|
34 |
+
self.resblocks_3d = torch.nn.Sequential()
|
35 |
+
for i in range(num_resblocks):
|
36 |
+
self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
|
37 |
+
|
38 |
+
def forward(self, source_image):
|
39 |
+
out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256
|
40 |
+
|
41 |
+
for i in range(len(self.down_blocks)):
|
42 |
+
out = self.down_blocks[i](out)
|
43 |
+
out = self.second(out)
|
44 |
+
bs, c, h, w = out.shape # ->Bx512x64x64
|
45 |
+
|
46 |
+
f_s = out.view(bs, self.reshape_channel, self.reshape_depth, h, w) # ->Bx32x16x64x64
|
47 |
+
f_s = self.resblocks_3d(f_s) # ->Bx32x16x64x64
|
48 |
+
return f_s
|
src/modules/convnextv2.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
This moudle is adapted to the ConvNeXtV2 version for the extraction of implicit keypoints, poses, and expression deformation.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
# from timm.models.layers import trunc_normal_, DropPath
|
10 |
+
from .util import LayerNorm, DropPath, trunc_normal_, GRN
|
11 |
+
|
12 |
+
__all__ = ['convnextv2_tiny']
|
13 |
+
|
14 |
+
|
15 |
+
class Block(nn.Module):
|
16 |
+
""" ConvNeXtV2 Block.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
dim (int): Number of input channels.
|
20 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, dim, drop_path=0.):
|
24 |
+
super().__init__()
|
25 |
+
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
|
26 |
+
self.norm = LayerNorm(dim, eps=1e-6)
|
27 |
+
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
|
28 |
+
self.act = nn.GELU()
|
29 |
+
self.grn = GRN(4 * dim)
|
30 |
+
self.pwconv2 = nn.Linear(4 * dim, dim)
|
31 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
input = x
|
35 |
+
x = self.dwconv(x)
|
36 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
37 |
+
x = self.norm(x)
|
38 |
+
x = self.pwconv1(x)
|
39 |
+
x = self.act(x)
|
40 |
+
x = self.grn(x)
|
41 |
+
x = self.pwconv2(x)
|
42 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
43 |
+
|
44 |
+
x = input + self.drop_path(x)
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
class ConvNeXtV2(nn.Module):
|
49 |
+
""" ConvNeXt V2
|
50 |
+
|
51 |
+
Args:
|
52 |
+
in_chans (int): Number of input image channels. Default: 3
|
53 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
54 |
+
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
|
55 |
+
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
|
56 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.
|
57 |
+
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
in_chans=3,
|
63 |
+
depths=[3, 3, 9, 3],
|
64 |
+
dims=[96, 192, 384, 768],
|
65 |
+
drop_path_rate=0.,
|
66 |
+
**kwargs
|
67 |
+
):
|
68 |
+
super().__init__()
|
69 |
+
self.depths = depths
|
70 |
+
self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
|
71 |
+
stem = nn.Sequential(
|
72 |
+
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
|
73 |
+
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
|
74 |
+
)
|
75 |
+
self.downsample_layers.append(stem)
|
76 |
+
for i in range(3):
|
77 |
+
downsample_layer = nn.Sequential(
|
78 |
+
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
79 |
+
nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
|
80 |
+
)
|
81 |
+
self.downsample_layers.append(downsample_layer)
|
82 |
+
|
83 |
+
self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
|
84 |
+
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
85 |
+
cur = 0
|
86 |
+
for i in range(4):
|
87 |
+
stage = nn.Sequential(
|
88 |
+
*[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
|
89 |
+
)
|
90 |
+
self.stages.append(stage)
|
91 |
+
cur += depths[i]
|
92 |
+
|
93 |
+
self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
|
94 |
+
|
95 |
+
# NOTE: the output semantic items
|
96 |
+
num_bins = kwargs.get('num_bins', 66)
|
97 |
+
num_kp = kwargs.get('num_kp', 24) # the number of implicit keypoints
|
98 |
+
self.fc_kp = nn.Linear(dims[-1], 3 * num_kp) # implicit keypoints
|
99 |
+
|
100 |
+
# print('dims[-1]: ', dims[-1])
|
101 |
+
self.fc_scale = nn.Linear(dims[-1], 1) # scale
|
102 |
+
self.fc_pitch = nn.Linear(dims[-1], num_bins) # pitch bins
|
103 |
+
self.fc_yaw = nn.Linear(dims[-1], num_bins) # yaw bins
|
104 |
+
self.fc_roll = nn.Linear(dims[-1], num_bins) # roll bins
|
105 |
+
self.fc_t = nn.Linear(dims[-1], 3) # translation
|
106 |
+
self.fc_exp = nn.Linear(dims[-1], 3 * num_kp) # expression / delta
|
107 |
+
|
108 |
+
def _init_weights(self, m):
|
109 |
+
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
110 |
+
trunc_normal_(m.weight, std=.02)
|
111 |
+
nn.init.constant_(m.bias, 0)
|
112 |
+
|
113 |
+
def forward_features(self, x):
|
114 |
+
for i in range(4):
|
115 |
+
x = self.downsample_layers[i](x)
|
116 |
+
x = self.stages[i](x)
|
117 |
+
return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
x = self.forward_features(x)
|
121 |
+
|
122 |
+
# implicit keypoints
|
123 |
+
kp = self.fc_kp(x)
|
124 |
+
|
125 |
+
# pose and expression deformation
|
126 |
+
pitch = self.fc_pitch(x)
|
127 |
+
yaw = self.fc_yaw(x)
|
128 |
+
roll = self.fc_roll(x)
|
129 |
+
t = self.fc_t(x)
|
130 |
+
exp = self.fc_exp(x)
|
131 |
+
scale = self.fc_scale(x)
|
132 |
+
|
133 |
+
ret_dct = {
|
134 |
+
'pitch': pitch,
|
135 |
+
'yaw': yaw,
|
136 |
+
'roll': roll,
|
137 |
+
't': t,
|
138 |
+
'exp': exp,
|
139 |
+
'scale': scale,
|
140 |
+
|
141 |
+
'kp': kp, # canonical keypoint
|
142 |
+
}
|
143 |
+
|
144 |
+
return ret_dct
|
145 |
+
|
146 |
+
|
147 |
+
def convnextv2_tiny(**kwargs):
|
148 |
+
model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
|
149 |
+
return model
|
src/modules/dense_motion.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
The module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving
|
5 |
+
"""
|
6 |
+
|
7 |
+
from torch import nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torch
|
10 |
+
from .util import Hourglass, make_coordinate_grid, kp2gaussian
|
11 |
+
|
12 |
+
|
13 |
+
class DenseMotionNetwork(nn.Module):
|
14 |
+
def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=True):
|
15 |
+
super(DenseMotionNetwork, self).__init__()
|
16 |
+
self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) # ~60+G
|
17 |
+
|
18 |
+
self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) # 65G! NOTE: computation cost is large
|
19 |
+
self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) # 0.8G
|
20 |
+
self.norm = nn.BatchNorm3d(compress, affine=True)
|
21 |
+
self.num_kp = num_kp
|
22 |
+
self.flag_estimate_occlusion_map = estimate_occlusion_map
|
23 |
+
|
24 |
+
if self.flag_estimate_occlusion_map:
|
25 |
+
self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
|
26 |
+
else:
|
27 |
+
self.occlusion = None
|
28 |
+
|
29 |
+
def create_sparse_motions(self, feature, kp_driving, kp_source):
|
30 |
+
bs, _, d, h, w = feature.shape # (bs, 4, 16, 64, 64)
|
31 |
+
identity_grid = make_coordinate_grid((d, h, w), ref=kp_source) # (16, 64, 64, 3)
|
32 |
+
identity_grid = identity_grid.view(1, 1, d, h, w, 3) # (1, 1, d=16, h=64, w=64, 3)
|
33 |
+
coordinate_grid = identity_grid - kp_driving.view(bs, self.num_kp, 1, 1, 1, 3)
|
34 |
+
|
35 |
+
k = coordinate_grid.shape[1]
|
36 |
+
|
37 |
+
# NOTE: there lacks an one-order flow
|
38 |
+
driving_to_source = coordinate_grid + kp_source.view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3)
|
39 |
+
|
40 |
+
# adding background feature
|
41 |
+
identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
|
42 |
+
sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) # (bs, 1+num_kp, d, h, w, 3)
|
43 |
+
return sparse_motions
|
44 |
+
|
45 |
+
def create_deformed_feature(self, feature, sparse_motions):
|
46 |
+
bs, _, d, h, w = feature.shape
|
47 |
+
feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w)
|
48 |
+
feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w)
|
49 |
+
sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3)
|
50 |
+
sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False)
|
51 |
+
sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w)
|
52 |
+
|
53 |
+
return sparse_deformed
|
54 |
+
|
55 |
+
def create_heatmap_representations(self, feature, kp_driving, kp_source):
|
56 |
+
spatial_size = feature.shape[3:] # (d=16, h=64, w=64)
|
57 |
+
gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
|
58 |
+
gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
|
59 |
+
heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w)
|
60 |
+
|
61 |
+
# adding background feature
|
62 |
+
zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()).to(heatmap.device)
|
63 |
+
heatmap = torch.cat([zeros, heatmap], dim=1)
|
64 |
+
heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w)
|
65 |
+
return heatmap
|
66 |
+
|
67 |
+
def forward(self, feature, kp_driving, kp_source):
|
68 |
+
bs, _, d, h, w = feature.shape # (bs, 32, 16, 64, 64)
|
69 |
+
|
70 |
+
feature = self.compress(feature) # (bs, 4, 16, 64, 64)
|
71 |
+
feature = self.norm(feature) # (bs, 4, 16, 64, 64)
|
72 |
+
feature = F.relu(feature) # (bs, 4, 16, 64, 64)
|
73 |
+
|
74 |
+
out_dict = dict()
|
75 |
+
|
76 |
+
# 1. deform 3d feature
|
77 |
+
sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) # (bs, 1+num_kp, d, h, w, 3)
|
78 |
+
deformed_feature = self.create_deformed_feature(feature, sparse_motion) # (bs, 1+num_kp, c=4, d=16, h=64, w=64)
|
79 |
+
|
80 |
+
# 2. (bs, 1+num_kp, d, h, w)
|
81 |
+
heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) # (bs, 1+num_kp, 1, d, h, w)
|
82 |
+
|
83 |
+
input = torch.cat([heatmap, deformed_feature], dim=2) # (bs, 1+num_kp, c=5, d=16, h=64, w=64)
|
84 |
+
input = input.view(bs, -1, d, h, w) # (bs, (1+num_kp)*c=105, d=16, h=64, w=64)
|
85 |
+
|
86 |
+
prediction = self.hourglass(input)
|
87 |
+
|
88 |
+
mask = self.mask(prediction)
|
89 |
+
mask = F.softmax(mask, dim=1) # (bs, 1+num_kp, d=16, h=64, w=64)
|
90 |
+
out_dict['mask'] = mask
|
91 |
+
mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
|
92 |
+
sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w)
|
93 |
+
deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) mask take effect in this place
|
94 |
+
deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3)
|
95 |
+
|
96 |
+
out_dict['deformation'] = deformation
|
97 |
+
|
98 |
+
if self.flag_estimate_occlusion_map:
|
99 |
+
bs, _, d, h, w = prediction.shape
|
100 |
+
prediction_reshape = prediction.view(bs, -1, h, w)
|
101 |
+
occlusion_map = torch.sigmoid(self.occlusion(prediction_reshape)) # Bx1x64x64
|
102 |
+
out_dict['occlusion_map'] = occlusion_map
|
103 |
+
|
104 |
+
return out_dict
|