# encoding = 'utf-8'
"""Audio-driven talking-head inference pipeline (LiveVASA).

Run as a script, this module animates one reference image with one driving
audio clip and writes the resulting video into the chosen output directory.
"""
import os
import os.path as osp
import random
import sys
import time

import cv2
import numpy as np
import torch
from omegaconf import OmegaConf

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
# disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warning
# NOTE: the import-time ``set_seed(42)`` call below switches benchmark back off.
torch.backends.cudnn.benchmark = True

# Add the directory three levels above this file's directory to the import
# path (presumably the repository root) so the ``src.*`` imports resolve.
sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__))))))

from src.datasets.preprocess.extract_features.audio_processer import AudioProcessor
from src.datasets.preprocess.extract_features.motion_processer import MotionProcesser
from src.models.dit.talking_head_diffusion import MotionDiffusion
from src.utils.rprint import rlog as log

# Emotion id -> display name; id 8 means "no emotion label".
emo_map = {
    0: 'Anger',
    1: 'Contempt',
    2: 'Disgust',
    3: 'Fear',
    4: 'Happiness',
    5: 'Neutral',
    6: 'Sadness',
    7: 'Surprise',
    8: 'None'
}

# Indices (within the 21 expression keypoints) that modulate_lip amplifies.
_LIP_KEYPOINT_INDICES = (6, 12, 14, 17, 19, 20)


def set_seed(seed: int = 42):
    """Seed every RNG used here and force deterministic cuDNN behaviour.

    Args:
        seed: Value passed to the torch (CPU and CUDA), NumPy and ``random``
            seeding functions.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers the multi-GPU case
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Turn off cuDNN autotuning so results are reproducible.
    torch.backends.cudnn.benchmark = False


# Seed once at import time, before any inference runs.
set_seed(42)


class NullableArgs:
    """Plain attribute bag holding a copy of another object's attributes."""

    def __init__(self, namespace):
        """Copy every entry of ``namespace.__dict__`` onto this instance."""
        for key, value in namespace.__dict__.items():
            setattr(self, key, value)


class LiveVASAPipeline(object):
    """Pipeline that turns a reference image plus driving audio into a video."""

    def __init__(self, cfg_path: str, load_motion_generator: bool = True, motion_mean_std_path=None):
        """Build the audio processor, motion generator and motion processer.

        Args:
            cfg_path (str): YAML config file path of LiveVASA. The keys read
                here are ``device_id``, ``audio_model_config``,
                ``motion_models_config``, ``motion_generator_path`` and
                ``motion_processer_config``.
            load_motion_generator (bool): When False, or when
                ``motion_models_config`` is None, ``self.motion_generator`` is
                left as None and must be provided later through
                ``set_motion_generator``.
            motion_mean_std_path: Optional path to a ``torch.save``-d mapping
                with ``"mean"`` and ``"std"`` tensors used to (de)normalize
                motion features. None disables normalization.
        """
        cfg = OmegaConf.load(cfg_path)
        self.device_id = cfg.device_id
        self.device = f"cuda:{self.device_id}"

        # 1. load audio processor
        self.audio_processor: AudioProcessor = AudioProcessor(cfg_path=cfg.audio_model_config, is_training=False)
        log("Load audio_processor done.")

        # 2. load motion generator (optional)
        if cfg.motion_models_config is not None and load_motion_generator:
            motion_models_config = OmegaConf.load(cfg.motion_models_config)
            log(f"Load motion_models_config from {osp.realpath(cfg.motion_models_config)} done.")
            self.motion_generator = MotionDiffusion(motion_models_config, device=self.device)
            self.load_motion_generator(self.motion_generator, cfg.motion_generator_path)
        else:
            self.motion_generator = None
            log("Init motion_generator as None.")

        # 3. load motion processer
        self.motion_processer: MotionProcesser = MotionProcesser(cfg_path=cfg.motion_processer_config, device_id=cfg.device_id)
        log("Load motion_processor done.")

        self.motion_mean_std = None
        if motion_mean_std_path is not None:
            # Load on CPU first so statistics saved from another device still
            # deserialize, then move them to the pipeline device.
            self.motion_mean_std = torch.load(motion_mean_std_path, map_location="cpu")
            self.motion_mean_std["mean"] = self.motion_mean_std["mean"].to(self.device)
            self.motion_mean_std["std"] = self.motion_mean_std["std"].to(self.device)
            mean = self.motion_mean_std["mean"]
            std = self.motion_mean_std["std"]
            # Report the statistics of the pose-related columns.
            for label, cols in (
                ("scale", slice(63, 64)),
                ("t", slice(64, 67)),
                ("pitch", slice(67, 68)),
                ("yaw", slice(68, 69)),
                ("roll", slice(69, 70)),
            ):
                print(f"{label} mean: {mean[0, cols]}, std: {std[0, cols]}")
        self.cfg = cfg

    def set_motion_generator(self, motion_generator: MotionDiffusion):
        """Install ``motion_generator`` and move it to the pipeline device."""
        self.motion_generator = motion_generator
        self.motion_generator.to(self.device)

    def load_motion_generator(self, model, motion_generator_path: str):
        """Load weights from ``motion_generator_path`` into ``model``.

        The state dict is loaded with ``strict=False``, so keys that do not
        match are not treated as an error. The model is then moved to the
        pipeline device and switched to eval mode.
        """
        print(motion_generator_path)
        model_data = torch.load(motion_generator_path, map_location=self.device)
        model.load_state_dict(model_data, strict=False)
        model.to(self.device)
        model.eval()

    def modulate_lip(self, standard_motion: torch.Tensor, motions: torch.Tensor, alpha=5, beta=0.1):
        """Amplify the lip keypoints' deviation from a reference frame.

        The first 63 features of each frame are read as 21 keypoints x 3
        coordinates. For the keypoints in ``_LIP_KEYPOINT_INDICES`` the offset
        from ``standard_motion`` is scaled by ``1 + weight``, where ``weight``
        grows with the frame's squared lip offset.

        Args:
            standard_motion: 1-D reference frame; only its first 63 values are
                used.
            motions: Tensor of frames (T x D, D >= 63). Modified in place.
            alpha: Gain applied to the squared lip distance before the sigmoid.
            beta: Factor applied to the maximum squared lip distance in the
                weight's denominator.

        Returns:
            ``motions``, with its first 63 columns updated.
        """
        standard_exp = standard_motion[:63].reshape(1, 21, 3)
        exps = motions[:, :63].reshape(-1, 21, 3)
        exp_deltas = exps - standard_exp

        # calc weights
        lip_deltas = torch.stack(
            [exp_deltas[:, lip_idx, :] for lip_idx in _LIP_KEYPOINT_INDICES], dim=1
        )  # T, 6, 3
        lip_deltas = lip_deltas.view(lip_deltas.shape[0], -1)
        lip_dist = torch.sum(lip_deltas ** 2, dim=-1, keepdim=True)  # T, 1
        max_dist = torch.max(lip_dist, dim=0)[0].squeeze()  # scalar
        # sigmoid(x) - 0.5 is zero at zero distance, so a frame equal to the
        # reference keeps its values unchanged.
        weight = (torch.sigmoid(lip_dist * alpha) - 0.5) / (max_dist * beta + 0.05)

        # modulation
        for lip_idx in _LIP_KEYPOINT_INDICES:
            exps[:, lip_idx, :] = standard_exp[:, lip_idx, :] + exp_deltas[:, lip_idx, :] * (1 + weight)
        motions[:, :63] = exps.flatten(-2, -1)
        return motions

    def get_motion_sequence(self, motion_data: torch.Tensor, rescale_ratio=1.0):
        """De-normalize generated motion and split it into named features.

        Column layout read from ``motion_data`` (frames x features):
        ``0:63`` exp, ``63:64`` scale, ``64:67`` t, then pitch / yaw / roll as
        one column each, or 66 columns each when there are more than 70
        features.

        Args:
            motion_data: Generated motion. When normalization statistics are
                set and there are more than 70 features, it is de-normalized
                in place.
            rescale_ratio: Factor multiplied into ``scale`` and ``t``.

        Returns:
            dict mapping ``"exp"``, ``"scale"``, ``"t"``, ``"pitch"``,
            ``"yaw"`` and ``"roll"`` to tensors stacked over frames.
        """
        n_frames = motion_data.shape[0]
        wide_pose = motion_data.shape[1] > 70

        # denorm
        if self.motion_mean_std is not None:
            mean = self.motion_mean_std["mean"]
            std = self.motion_mean_std["std"]
            if wide_pose:
                motion_data[:, :63] = motion_data[:, :63] * (std[:, :63] + 1e-5) + mean[:, :63]
                # denorm pose: only the mean is added back, no std scaling
                motion_data[:, 63:] = motion_data[:, 63:] + mean[:, 63:]
            else:
                motion_data = motion_data * (std + 1e-5) + mean

        kp_infos = {"exp": [], "scale": [], "t": [], "pitch": [], "yaw": [], "roll": []}
        # Kept as a per-frame loop so each frame broadcasts against
        # ``rescale_ratio`` exactly as before (it may be a tensor).
        for idx in range(n_frames):
            frame = motion_data[idx]
            kp_infos["exp"].append(frame[:63])
            kp_infos["scale"].append(frame[63:64] * rescale_ratio)
            kp_infos["t"].append(frame[64:67] * rescale_ratio)
            if wide_pose:
                kp_infos["pitch"].append(frame[67:133])
                kp_infos["yaw"].append(frame[133:199])
                kp_infos["roll"].append(frame[199:265])
            else:
                kp_infos["pitch"].append(frame[67:68])
                kp_infos["yaw"].append(frame[68:69])
                kp_infos["roll"].append(frame[69:70])

        for k, v in kp_infos.items():
            kp_infos[k] = torch.stack(v)
        return kp_infos

    def get_prev_motion(self, x_s_info):
        """Build the normalized motion vector of the source image.

        Args:
            x_s_info: Mapping with at least ``"exp"``, ``"scale"``, ``"t"``,
                ``"pitch"``, ``"yaw"`` and ``"roll"`` tensors. Its ``"t"``
                entry is modified in place (third component zeroed).

        Returns:
            Tuple ``(kp_infos, rescale_ratio)``. ``kp_infos`` is the
            concatenated feature vector with an extra axis inserted at
            position 1. ``rescale_ratio`` is 1.0 without statistics, otherwise
            the source scale divided by the mean scale.
        """
        x_s_info["t"][:, 2] = 0  # zero tz
        if self.motion_generator is not None and self.motion_generator.input_dim == 70:
            x_s_info = self.motion_processer.refine_kp(x_s_info)
            # NOTE(review): the nesting of this loop was reconstructed from a
            # whitespace-collapsed source; confirm it belongs inside this branch.
            for k, v in x_s_info.items():
                x_s_info[k] = v.reshape(1, -1)

        if self.motion_mean_std is None:
            rescale_ratio = 1.0
        else:
            rescale_ratio = (x_s_info["scale"] + 1e-5) / (self.motion_mean_std["mean"][:, 63:64] + 1e-5)

        kp_infos = []
        for feat_name in ["exp", "scale", "t", "pitch", "yaw", "roll"]:
            if feat_name in ["scale", "t"]:
                # set scale as the mean scale
                kp_infos.append(x_s_info[feat_name] / rescale_ratio)
            else:
                kp_infos.append(x_s_info[feat_name])
        kp_infos = torch.cat(kp_infos, dim=-1)  # B, D

        # normalize
        if self.motion_mean_std is not None:
            mean = self.motion_mean_std["mean"]
            std = self.motion_mean_std["std"]
            if self.motion_generator is not None and self.motion_generator.input_dim > 70:
                # normalize exp
                kp_infos[:, :63] = (kp_infos[:, :63] - mean[:, :63]) / (std[:, :63] + 1e-5)
                # normalize pose: only the mean is subtracted
                kp_infos[:, 63:] = kp_infos[:, 63:] - mean[:, 63:]
            else:
                kp_infos = (kp_infos - mean) / (std + 1e-5)

        kp_infos = kp_infos.unsqueeze(1)  # B, 1, D
        return kp_infos, rescale_ratio

    def process_audio(self, audio_path: str, silent_audio_path = None, mode="post"):
        """Pad the audio with silence and compute its embedding.

        Args:
            audio_path: Path of the driving audio.
            silent_audio_path: Path handed to ``add_silent_audio`` as the
                silence source.
            mode: Padding mode forwarded to ``add_silent_audio``.

        Returns:
            Tuple ``(audio_emb, audio_path, add_frames, ori_audio_path)``:
            the embedding of the padded audio, the padded audio path, the
            frame count reported by ``add_silent_audio``, and the original
            path.
        """
        # add silent audio to pad short input
        ori_audio_path = audio_path
        audio_path, add_frames = self.audio_processor.add_silent_audio(
            audio_path, silent_audio_path, add_duration=2, linear_fusion=False, mode=mode
        )
        audio_emb = self.audio_processor.get_long_audio_emb(audio_path)
        return audio_emb, audio_path, add_frames, ori_audio_path

    def driven_sample(self, image_path: str, audio_path: str, cfg_scale: float=1., emo: int=8, save_dir=None, smooth=False, silent_audio_path = None, silent_mode="post"):
        """Generate a video of ``image_path`` driven by ``audio_path``.

        Args:
            image_path: Reference image. The output file is named after the
                part of its base name before the first ``.``.
            audio_path: Driving audio.
            cfg_scale: Forwarded to ``motion_generator.sample``.
            emo: Emotion id forwarded to ``motion_generator.sample`` (see
                ``emo_map``).
            save_dir: Output directory; falls back to ``cfg.output_dir``.
                Created if missing.
            smooth: Forwarded to ``motion_processer.driven_by_audio``.
            silent_audio_path: Silence clip used for padding the audio.
            silent_mode: ``"both"``, ``"pre"`` or anything else (treated as
                post-padding); decides which padded frames are trimmed.

        Returns:
            Path of the written ``.mp4`` file.

        Raises:
            AssertionError: If no motion generator has been set.
        """
        assert self.motion_generator is not None, "Motion Generator is not set"
        reference_name = osp.basename(image_path).split('.')[0]

        # get audio embeddings
        audio_emb, audio_path, add_frames, ori_audio_path = self.process_audio(audio_path, silent_audio_path, mode=silent_mode)

        # get src image infos
        source_rgb_lst = self.motion_processer.read_image(image_path)
        src_img_256x256, s_lmk, crop_info = self.motion_processer.crop_image(source_rgb_lst[0], do_crop=True)
        f_s, x_s_info = self.motion_processer.prepare_source(src_img_256x256)
        prev_motion, rescale_ratio = self.get_prev_motion(x_s_info)

        # generate motions
        motion = self.motion_generator.sample(audio_emb, x_s_info["kp"], prev_motion=prev_motion, cfg_scale=cfg_scale, emo=emo)

        # Only touch the sequence when padding was actually added; slicing
        # with ``-0`` would otherwise produce an empty sequence.
        if add_frames > 0:
            # Use a frame from inside the trailing padding as the reference.
            standard_motion = motion[-max(add_frames*3//4, 1)]
            motion = self.modulate_lip(standard_motion, motion, alpha=5)
            if silent_mode == "both":
                motion = motion[add_frames:-add_frames]
            elif silent_mode == "pre":
                motion = motion[add_frames:]
            else:
                motion = motion[:-add_frames]

        print(f"length of motion: {len(motion)}")
        kp_infos = self.get_motion_sequence(motion, rescale_ratio=rescale_ratio)

        # driven results
        if save_dir is None:
            save_dir = self.cfg.output_dir
        os.makedirs(save_dir, exist_ok=True)

        save_path = osp.join(save_dir, f'{reference_name}.mp4')
        self.motion_processer.driven_by_audio(source_rgb_lst[0], kp_infos, save_path, ori_audio_path, smooth=smooth)
        return save_path

    def viz_motion(self, motion_data):
        """Placeholder; currently does nothing."""
        pass

    def __call__(self):
        """Placeholder; currently does nothing."""
        pass


def main():
    """Parse command-line arguments and render a single video."""
    import argparse

    parser = argparse.ArgumentParser(description="Arguments for the task")
    parser.add_argument('--task', type=str, default="test", help='Task to perform')
    parser.add_argument('--cfg_path', type=str, default="configs/audio2motion/inference/inference.yaml", help='Path to configuration file')
    parser.add_argument('--image_path', type=str, default="src/examples/reference_images/6.jpg", help='Path to the input image')
    parser.add_argument('--audio_path', type=str, default="src/examples/driving_audios/5.wav", help='Path to the driving audio')
    parser.add_argument('--silent_audio_path', type=str, default="src/examples/silent-audio.wav", help='Path to silent audio file')
    parser.add_argument('--save_dir', type=str, default="output/", help='Directory to save results')
    parser.add_argument('--motion_mean_std_path', type=str, default="src/datasets/mean.pt", help='Path to motion mean and standard deviation file')
    parser.add_argument('--cfg_scale', type=float, default=1.2, help='Scaling factor for the configuration')
    args = parser.parse_args()

    pipeline = LiveVASAPipeline(cfg_path=args.cfg_path, motion_mean_std_path=args.motion_mean_std_path)

    emo = 8
    # One sub-directory per (cfg_scale, emotion) setting; parents are created too.
    save_dir = osp.join(args.save_dir, f"cfg-{args.cfg_scale}-emo-{emo_map[emo]}")
    os.makedirs(save_dir, exist_ok=True)

    video_path = pipeline.driven_sample(
        args.image_path,
        args.audio_path,
        cfg_scale=args.cfg_scale,
        emo=emo,
        save_dir=save_dir,
        smooth=False,
        silent_audio_path=args.silent_audio_path,
    )
    print(f"Video Result has been saved into: {video_path}")


if __name__ == "__main__":
    main()