Spaces:
Paused
Paused
File size: 7,856 Bytes
6764da3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 |
"""run bash scripts/download_models.sh first to prepare the weights file"""
import os
import shutil
from argparse import Namespace
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from cog import BasePredictor, Input, Path
# Directory holding the prepared model weights (see the module docstring).
# Predictor.setup joins every checkpoint file name against it. The path is
# relative, so it resolves against the process's current working directory.
checkpoints = "checkpoints"
class Predictor(BasePredictor):
    """Cog predictor that renders a video from a source image and a driving audio file.

    ``setup`` builds every model once. ``predict`` then runs three stages:
    3DMM coefficient extraction from the inputs, audio-to-coefficient
    generation, and coefficient-to-video rendering.
    """

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient.

        Creates ``self.preprocess_model``, ``self.audio_to_coeff`` and
        ``self.animate_from_coeff`` (a dict with one renderer for
        ``preprocess == "full"`` and one for every other mode) from the weight
        files under ``checkpoints``.
        """
        # Stored on the instance so predict() reuses the device the models
        # were built for instead of repeating the literal.
        self.device = "cuda"
        device = self.device

        path_of_lm_croper = os.path.join(
            checkpoints, "shape_predictor_68_face_landmarks.dat"
        )
        path_of_net_recon_model = os.path.join(checkpoints, "epoch_20.pth")
        dir_of_BFM_fitting = os.path.join(checkpoints, "BFM_Fitting")
        wav2lip_checkpoint = os.path.join(checkpoints, "wav2lip.pth")

        # NOTE(review): the "auido" spelling is kept exactly as found;
        # presumably it matches the downloaded file names - verify before
        # "correcting" it.
        audio2pose_checkpoint = os.path.join(checkpoints, "auido2pose_00140-model.pth")
        audio2pose_yaml_path = os.path.join("src", "config", "auido2pose.yaml")
        audio2exp_checkpoint = os.path.join(checkpoints, "auido2exp_00300-model.pth")
        audio2exp_yaml_path = os.path.join("src", "config", "auido2exp.yaml")

        free_view_checkpoint = os.path.join(
            checkpoints, "facevid2vid_00189-model.pth.tar"
        )

        # init model
        self.preprocess_model = CropAndExtract(
            path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device
        )

        self.audio_to_coeff = Audio2Coeff(
            audio2pose_checkpoint,
            audio2pose_yaml_path,
            audio2exp_checkpoint,
            audio2exp_yaml_path,
            wav2lip_checkpoint,
            device,
        )

        # Two renderers share the same free-view checkpoint but use different
        # mapping weights and configs; predict() picks one by `preprocess`.
        self.animate_from_coeff = {
            "full": AnimateFromCoeff(
                free_view_checkpoint,
                os.path.join(checkpoints, "mapping_00109-model.pth.tar"),
                os.path.join("src", "config", "facerender_still.yaml"),
                device,
            ),
            "others": AnimateFromCoeff(
                free_view_checkpoint,
                os.path.join(checkpoints, "mapping_00229-model.pth.tar"),
                os.path.join("src", "config", "facerender.yaml"),
                device,
            ),
        }

    def _extract_reference_coeff(self, ref_video, results_dir, purpose):
        """Extract 3DMM coefficients from a reference video.

        Args:
            ref_video: path of the reference video (any path-like object).
            results_dir: directory in which a per-video frame folder is created.
            purpose: text appended to the progress message
                ("eye blinking" or "pose").

        Returns:
            The first element returned by ``self.preprocess_model.generate``,
            i.e. the coefficient path for the reference video.
        """
        # The source image reaches generate() as a plain str; convert here so
        # the reference videos are handed over as the same type.
        ref_video = str(ref_video)
        video_name = os.path.splitext(os.path.split(ref_video)[-1])[0]
        frame_dir = os.path.join(results_dir, video_name)
        os.makedirs(frame_dir, exist_ok=True)
        print(f"3DMM Extraction for the reference video providing {purpose}")
        coeff_path, _, _ = self.preprocess_model.generate(ref_video, frame_dir)
        return coeff_path

    def predict(
        self,
        source_image: Path = Input(
            description="Upload the source image, it can be video.mp4 or picture.png",
        ),
        driven_audio: Path = Input(
            description="Upload the driven audio, accepts .wav and .mp4 file",
        ),
        enhancer: str = Input(
            description="Choose a face enhancer",
            choices=["gfpgan", "RestoreFormer"],
            default="gfpgan",
        ),
        preprocess: str = Input(
            description="how to preprocess the images",
            choices=["crop", "resize", "full"],
            default="full",
        ),
        ref_eyeblink: Path = Input(
            description="path to reference video providing eye blinking",
            default=None,
        ),
        ref_pose: Path = Input(
            description="path to reference video providing pose",
            default=None,
        ),
        still: bool = Input(
            description="can crop back to the original videos for the full body aniamtion when preprocess is full",
            default=True,
        ),
    ) -> Path:
        """Run a single prediction on the model.

        Args:
            source_image: image or video to animate.
            driven_audio: audio that drives the animation.
            enhancer: face enhancer name forwarded to the renderer.
            preprocess: preprocessing mode; ``"full"`` selects the "full"
                renderer, any other value selects the "others" renderer.
            ref_eyeblink: optional reference video for eye blinking.
            ref_pose: optional reference video for pose.
            still: forwarded as ``still`` / ``still_mode`` to the batch builders.

        Returns:
            ``/tmp/out.mp4``, a copy of the enhanced video written to ``results``.

        Raises:
            ValueError: if no coefficients could be extracted from
                ``source_image``.
            FileNotFoundError: if rendering left no ``enhanced.mp4`` file in
                the results directory.
        """
        renderer_key = "full" if preprocess == "full" else "others"
        animate_from_coeff = self.animate_from_coeff[renderer_key]

        args = load_default()
        args.pic_path = str(source_image)
        args.audio_path = str(driven_audio)
        args.still = still
        args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
        args.ref_pose = None if ref_pose is None else str(ref_pose)
        device = self.device

        # Start every prediction from an empty results directory so the
        # output lookup at the end cannot pick up a file from an earlier run.
        results_dir = "results"
        if os.path.exists(results_dir):
            shutil.rmtree(results_dir)
        os.makedirs(results_dir)
        first_frame_dir = os.path.join(results_dir, "first_frame_dir")
        os.makedirs(first_frame_dir)

        # crop image and extract 3dmm from image
        print("3DMM Extraction for source image")
        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
            args.pic_path, first_frame_dir, preprocess, source_image_flag=True
        )
        if first_coeff_path is None:
            # The declared return type is Path; fail loudly instead of
            # returning None with only a printed message.
            raise ValueError("Can't get the coeffs of the input")

        if ref_eyeblink is not None:
            ref_eyeblink_coeff_path = self._extract_reference_coeff(
                args.ref_eyeblink, results_dir, "eye blinking"
            )
        else:
            ref_eyeblink_coeff_path = None

        if ref_pose is None:
            ref_pose_coeff_path = None
        elif ref_pose == ref_eyeblink:
            # Same video for both roles: reuse the coefficients just extracted.
            ref_pose_coeff_path = ref_eyeblink_coeff_path
        else:
            ref_pose_coeff_path = self._extract_reference_coeff(
                args.ref_pose, results_dir, "pose"
            )

        # audio2coeff
        batch = get_data(
            first_coeff_path,
            args.audio_path,
            device,
            ref_eyeblink_coeff_path,
            still=still,
        )
        coeff_path = self.audio_to_coeff.generate(
            batch, results_dir, args.pose_style, ref_pose_coeff_path
        )

        # coeff2video
        print("coeff2video")
        data = get_facerender_data(
            coeff_path,
            crop_pic_path,
            first_coeff_path,
            args.audio_path,
            args.batch_size,
            args.input_yaw,
            args.input_pitch,
            args.input_roll,
            expression_scale=args.expression_scale,
            still_mode=still,
            preprocess=preprocess,
        )
        animate_from_coeff.generate(
            data,
            results_dir,
            args.pic_path,
            crop_info,
            enhancer=enhancer,
            background_enhancer=args.background_enhancer,
            preprocess=preprocess,
        )

        # Take the first entry whose name contains "enhanced.mp4"; report a
        # clear error when there is none instead of an IndexError.
        enhanced_name = next(
            (name for name in os.listdir(results_dir) if "enhanced.mp4" in name),
            None,
        )
        if enhanced_name is None:
            raise FileNotFoundError(
                f"No enhanced.mp4 output was produced in '{results_dir}'"
            )

        output = "/tmp/out.mp4"
        shutil.copy(os.path.join(results_dir, enhanced_name), output)
        return Path(output)
def load_default():
    """Build a fresh ``Namespace`` holding the default inference settings.

    A new object is created on every call, so a caller may overwrite or add
    attributes without affecting later calls.
    """
    # Settings that Predictor.predict reads back when generating and rendering.
    generation = {
        "pose_style": 0,
        "batch_size": 2,
        "expression_scale": 1.0,
        "input_yaw": None,
        "input_pitch": None,
        "input_roll": None,
        "background_enhancer": None,
    }
    # Remaining defaults. Nothing in the visible part of this file reads them;
    # presumably they mirror options expected elsewhere - TODO confirm.
    extras = {
        "face3dvis": False,
        "net_recon": "resnet50",
        "init_path": None,
        "use_last_fc": False,
        "bfm_folder": "./checkpoints/BFM_Fitting/",
        "bfm_model": "BFM_model_front.mat",
        "focal": 1015.0,
        "center": 112.0,
        "camera_d": 10.0,
        "z_near": 5.0,
        "z_far": 15.0,
    }
    return Namespace(**generation, **extras)
|