|  |  | 
					
						
						|  |  | 
					
						
"""
Config dataclass used for inference.
"""
					
						
						|  |  | 
					
						
import os.path as osp
from dataclasses import dataclass, field
from typing import Literal, Tuple

import cv2
from numpy import ndarray

from .base_config import PrintableConfig, make_abs_path
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @dataclass(repr=False) | 
					
						
						|  | class InferenceConfig(PrintableConfig): | 
					
						
						|  | models_config: str = make_abs_path('./models.yaml') | 
					
						
						|  | checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') | 
					
						
						|  | checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') | 
					
						
						|  | checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') | 
					
						
						|  | checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') | 
					
						
						|  |  | 
					
						
						|  | checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') | 
					
						
						|  | flag_use_half_precision: bool = True | 
					
						
						|  |  | 
					
						
						|  | flag_lip_zero: bool = True | 
					
						
						|  | lip_zero_threshold: float = 0.03 | 
					
						
						|  |  | 
					
						
						|  | flag_eye_retargeting: bool = False | 
					
						
						|  | flag_lip_retargeting: bool = False | 
					
						
						|  | flag_stitching: bool = True | 
					
						
						|  |  | 
					
						
						|  | flag_relative: bool = True | 
					
						
						|  | anchor_frame: int = 0 | 
					
						
						|  |  | 
					
						
						|  | input_shape: Tuple[int, int] = (256, 256) | 
					
						
						|  | output_format: Literal['mp4', 'gif'] = 'mp4' | 
					
						
						|  | output_fps: int = 30 | 
					
						
						|  | crf: int = 15 | 
					
						
						|  |  | 
					
						
						|  | flag_write_result: bool = True | 
					
						
						|  | flag_pasteback: bool = True | 
					
						
						|  | mask_crop: ndarray = cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR) | 
					
						
						|  | flag_write_gif: bool = False | 
					
						
						|  | size_gif: int = 256 | 
					
						
						|  | ref_max_shape: int = 1280 | 
					
						
						|  | ref_shape_n: int = 2 | 
					
						
						|  |  | 
					
						
						|  | device_id: int = 0 | 
					
						
						|  | flag_do_crop: bool = False | 
					
						
						|  | flag_do_rot: bool = True | 
					
						
						|  |  |