# -*- coding: utf-8 -*-
# @Time : 2024/12/15
# @Author : wenshao
# @Email : wenshaoguo1026@gmail.com
# @Project : FasterLivePortrait
# @FileName: joyvasa_audio_to_motion_pipeline.py
import math
import torch
import torchaudio
import numpy as np
import torch.nn.functional as F
import pickle
from tqdm import tqdm
import pathlib
import os
from ..models.JoyVASA.dit_talking_head import DitTalkingHead
from ..models.JoyVASA.helper import NullableArgs
from ..utils import utils
class JoyVASAAudio2MotionPipeline:
"""
JoyVASA 声音生成LivePortrait Motion
"""
def __init__(self, **kwargs):
self.device, self.dtype = utils.get_opt_device_dtype()
        # On Windows, temporarily alias PosixPath to WindowsPath so checkpoints
        # pickled with POSIX paths can be unpickled.
if os.name == 'nt':
temp = pathlib.PosixPath
pathlib.PosixPath = pathlib.WindowsPath
motion_model_path = kwargs.get("motion_model_path", "")
audio_model_path = kwargs.get("audio_model_path", "")
motion_template_path = kwargs.get("motion_template_path", "")
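        # The checkpoint bundles the weights together with the training-time
        # args needed to rebuild the network.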
model_data = torch.load(motion_model_path, map_location="cpu")
model_args = NullableArgs(model_data['args'])
model = DitTalkingHead(motion_feat_dim=model_args.motion_feat_dim,
n_motions=model_args.n_motions,
n_prev_motions=model_args.n_prev_motions,
feature_dim=model_args.feature_dim,
audio_model=model_args.audio_model,
n_diff_steps=model_args.n_diff_steps,
audio_encoder_path=audio_model_path)
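        # Drop the positional-encoding buffer from the checkpoint; with
        # strict=False below, the model keeps its own freshly built buffer.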
model_data['model'].pop('denoising_net.TE.pe')
model.load_state_dict(model_data['model'], strict=False)
model.to(self.device, dtype=self.dtype)
model.eval()
# Restore the original PosixPath if it was changed
if os.name == 'nt':
pathlib.PosixPath = temp
self.motion_generator = model
self.n_motions = model_args.n_motions
self.n_prev_motions = model_args.n_prev_motions
self.fps = model_args.fps
        self.audio_unit = 16000. / self.fps  # number of audio samples per video frame
self.n_audio_samples = round(self.audio_unit * self.n_motions)
self.pad_mode = model_args.pad_mode
self.use_indicator = model_args.use_indicator
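        # Classifier-free guidance settings: larger cfg_scale follows the audio
        # condition more strongly.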
self.cfg_mode = kwargs.get("cfg_mode", "incremental")
self.cfg_cond = kwargs.get("cfg_cond", None)
self.cfg_scale = kwargs.get("cfg_scale", 2.8)
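        # The motion template stores the normalization statistics (mean/std for
        # expression, min/max for scale, translation and pose) used to
        # de-normalize the model outputs.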
with open(motion_template_path, 'rb') as fin:
            self.template_dict = pickle.load(fin)
@torch.inference_mode()
def gen_motion_sequence(self, audio_path, **kwargs):
# preprocess audio
audio, sample_rate = torchaudio.load(audio_path)
if sample_rate != 16000:
audio = torchaudio.functional.resample(
audio,
orig_freq=sample_rate,
new_freq=16000,
)
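        # Downmix to mono and move to the target device; the audio encoder
        # expects 16 kHz mono input.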
audio = audio.mean(0).to(self.device, dtype=self.dtype)
# audio = F.pad(audio, (1280, 640), "constant", 0)
# audio_mean, audio_std = torch.mean(audio), torch.std(audio)
# audio = (audio - audio_mean) / (audio_std + 1e-5)
        # split the audio into n_subdivision windows of n_motions frames each
clip_len = int(len(audio) / 16000 * self.fps)
stride = self.n_motions
if clip_len <= self.n_motions:
n_subdivision = 1
else:
n_subdivision = math.ceil(clip_len / stride)
        # pad the audio so it fills n_subdivision full windows, tracking how many frames the padding adds
n_padding_audio_samples = self.n_audio_samples * n_subdivision - len(audio)
n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit)
if n_padding_audio_samples > 0:
if self.pad_mode == 'zero':
padding_value = 0
elif self.pad_mode == 'replicate':
padding_value = audio[-1]
else:
raise ValueError(f'Unknown pad mode: {self.pad_mode}')
audio = F.pad(audio, (0, n_padding_audio_samples), value=padding_value)
        # generate motions window by window, chaining context across windows
coef_list = []
for i in range(0, n_subdivision):
start_idx = i * stride
end_idx = start_idx + self.n_motions
            indicator = torch.ones((1, self.n_motions), device=self.device) if self.use_indicator else None
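            # Zero the indicator over padded frames in the final window.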
if indicator is not None and i == n_subdivision - 1 and n_padding_frames > 0:
indicator[:, -n_padding_frames:] = 0
audio_in = audio[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0)
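            # The first window samples without previous context; later windows
            # condition on the previous window's tail (motion, audio, noise).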
            if i == 0:
                motion_feat, noise, prev_audio_feat = self.motion_generator.sample(
                    audio_in,
                    indicator=indicator,
                    cfg_mode=self.cfg_mode,
                    cfg_cond=self.cfg_cond,
                    cfg_scale=self.cfg_scale,
                    dynamic_threshold=0)
            else:
                motion_feat, noise, prev_audio_feat = self.motion_generator.sample(
                    audio_in,
                    prev_motion_feat.to(self.dtype),
                    prev_audio_feat.to(self.dtype),
                    noise.to(self.dtype),
                    indicator=indicator,
                    cfg_mode=self.cfg_mode,
                    cfg_cond=self.cfg_cond,
                    cfg_scale=self.cfg_scale,
                    dynamic_threshold=0)
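            # Keep the last n_prev_motions frames as conditioning context for
            # the next window.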
prev_motion_feat = motion_feat[:, -self.n_prev_motions:].clone()
prev_audio_feat = prev_audio_feat[:, -self.n_prev_motions:]
motion_coef = motion_feat
if i == n_subdivision - 1 and n_padding_frames > 0:
motion_coef = motion_coef[:, :-n_padding_frames] # delete padded frames
coef_list.append(motion_coef)
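        # Concatenate the windows into one (1, n_frames, 70) coefficient
        # sequence: 63 expression + 1 scale + 3 translation + 3 pose dims.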
motion_coef = torch.cat(coef_list, dim=1)
# motion_coef = self.reformat_motion(args, motion_coef)
motion_coef = motion_coef.squeeze().cpu().numpy().astype(np.float32)
        motion_list = []

        def denorm(x, key):
            # Map a min-max normalized value back to its original range.
            lo, hi = self.template_dict[f"min_{key}"], self.template_dict[f"max_{key}"]
            return x * (hi - lo) + lo

        for idx in tqdm(range(motion_coef.shape[0]), total=motion_coef.shape[0]):
            # De-normalize each coefficient group: exp (63 = 21x3), scale (1),
            # t (3), and the pose angles pitch/yaw/roll (1 each).
            exp = motion_coef[idx][:63] * self.template_dict["std_exp"] + self.template_dict["mean_exp"]
            scale = denorm(motion_coef[idx][63:64], "scale")
            t = denorm(motion_coef[idx][64:67], "t")
            pitch = denorm(motion_coef[idx][67:68], "pitch")
            yaw = denorm(motion_coef[idx][68:69], "yaw")
            roll = denorm(motion_coef[idx][69:70], "roll")
R = utils.get_rotation_matrix(pitch, yaw, roll)
R = R.reshape(1, 3, 3).astype(np.float32)
exp = exp.reshape(1, 21, 3).astype(np.float32)
scale = scale.reshape(1, 1).astype(np.float32)
t = t.reshape(1, 3).astype(np.float32)
pitch = pitch.reshape(1, 1).astype(np.float32)
yaw = yaw.reshape(1, 1).astype(np.float32)
roll = roll.reshape(1, 1).astype(np.float32)
motion_list.append({"exp": exp, "scale": scale, "R": R, "t": t, "pitch": pitch, "yaw": yaw, "roll": roll})
        tgt_motion = {
            'n_frames': motion_coef.shape[0],
            'output_fps': self.fps,
            'motion': motion_list,
            'c_eyes_lst': [],
            'c_lip_lst': [],
        }
return tgt_motion
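

if __name__ == "__main__":
    # Minimal usage sketch. The paths below are illustrative placeholders
    # (assumptions), not files shipped with this module; point them at your
    # own JoyVASA checkpoints and an audio file. Because of the relative
    # imports above, run this as a module from the repo root, e.g.
    # `python -m src.pipelines.joyvasa_audio_to_motion_pipeline`.
    pipeline = JoyVASAAudio2MotionPipeline(
        motion_model_path="checkpoints/JoyVASA/motion_generator.pt",
        audio_model_path="checkpoints/JoyVASA/audio_encoder",
        motion_template_path="checkpoints/JoyVASA/motion_template.pkl",
        cfg_scale=2.8,
    )
    motions = pipeline.gen_motion_sequence("assets/examples/driving/a1.wav")
    print(motions["n_frames"], "frames at", motions["output_fps"], "fps")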