|
|
|
|
|
import os.path as osp
|
|
import tyro
|
|
from src.config.argument_config import ArgumentConfig
|
|
from src.config.inference_config import InferenceConfig
|
|
from src.config.crop_config import CropConfig
|
|
from src.live_portrait_pipeline import LivePortraitPipeline
|
|
|
|
|
|
def partial_fields(target_class, kwargs):
    """Instantiate ``target_class`` using only the ``kwargs`` it recognises.

    A key is forwarded to the constructor only when ``hasattr(target_class, key)``
    is true. The check is made on the class object itself, so only class-level
    attributes qualify. All other keys are silently dropped.

    Args:
        target_class: The class to instantiate.
        kwargs: Mapping of candidate keyword arguments.

    Returns:
        ``target_class(**selected)``, where ``selected`` is the filtered mapping.
    """
    selected = {}
    for name, value in kwargs.items():
        if hasattr(target_class, name):
            selected[name] = value
    return target_class(**selected)
|
|
|
|
|
|
def fast_check_args(args: ArgumentConfig):
    """Fail fast when an input path from the command line does not exist.

    ``args.source_image`` is checked first and ``args.driving_info`` second.
    The first path for which ``osp.exists`` is false raises.

    Raises:
        FileNotFoundError: if either input path does not exist.
    """
    source_path = args.source_image
    if not osp.exists(source_path):
        raise FileNotFoundError(f"source image not found: {source_path}")

    driving_path = args.driving_info
    if not osp.exists(driving_path):
        raise FileNotFoundError(f"driving info not found: {driving_path}")
|
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments, build the pipeline, run it.

    Steps:
    1. Parse an ``ArgumentConfig`` from the command line with ``tyro.cli``.
    2. Validate the input paths with ``fast_check_args``.
    3. Split the parsed fields into an ``InferenceConfig`` and a ``CropConfig``
       using ``partial_fields``.
    4. Dispatch to ``execute_source_video`` when ``args.flag_svideo`` is truthy,
       and to ``execute`` otherwise.
    """
    tyro.extras.set_accent_color("bright_cyan")
    args = tyro.cli(ArgumentConfig)

    fast_check_args(args)

    # Both config classes draw from the same parsed fields. Each keeps only
    # the names it declares (see partial_fields).
    parsed_fields = args.__dict__
    inference_cfg = partial_fields(InferenceConfig, parsed_fields)
    crop_cfg = partial_fields(CropConfig, parsed_fields)

    pipeline = LivePortraitPipeline(inference_cfg=inference_cfg, crop_cfg=crop_cfg)

    runner = pipeline.execute_source_video if args.flag_svideo else pipeline.execute
    runner(args)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not on import.
    main()
|
|
|