Spaces:
Configuration error
Configuration error
# -*- coding: utf-8 -*-
# @Time : 2024/12/15
# @Author : wenshao
# @Email : wenshaoguo1026@gmail.com
# @Project : FasterLivePortrait
# @FileName: joyvasa_audio_to_motion_pipeline.py
import math | |
import pdb | |
import torch | |
import torchaudio | |
import numpy as np | |
import torch.nn.functional as F | |
import pickle | |
from tqdm import tqdm | |
import pathlib | |
import os | |
from ..models.JoyVASA.dit_talking_head import DitTalkingHead | |
from ..models.JoyVASA.helper import NullableArgs | |
from ..utils import utils | |
class JoyVASAAudio2MotionPipeline:
    """
    Generate LivePortrait motion from speech audio with JoyVASA.

    The pipeline loads a ``DitTalkingHead`` checkpoint plus a pickled motion
    template (normalisation statistics) and converts an audio file into a list
    of per-frame motion dictionaries.
    """

    # Sample rate (Hz) the audio is resampled to before it is fed to the model.
    _SAMPLE_RATE = 16000

    def __init__(self, **kwargs):
        """
        Load the motion generator and the motion template.

        Keyword Args:
            motion_model_path: checkpoint handed to ``torch.load``; it must
                contain the keys ``'args'`` and ``'model'``. Defaults to ``""``.
            audio_model_path: forwarded to ``DitTalkingHead`` as
                ``audio_encoder_path``. Defaults to ``""``.
            motion_template_path: pickle file holding the statistics that
                ``gen_motion_sequence`` uses to de-normalise the model output.
                Defaults to ``""``.
            cfg_mode: forwarded to ``sample`` (default ``"incremental"``).
            cfg_cond: forwarded to ``sample`` (default ``None``).
            cfg_scale: forwarded to ``sample`` (default ``2.8``).
        """
        self.device, self.dtype = utils.get_opt_device_dtype()

        motion_model_path = kwargs.get("motion_model_path", "")
        audio_model_path = kwargs.get("audio_model_path", "")
        motion_template_path = kwargs.get("motion_template_path", "")

        # On Windows, temporarily alias PosixPath to WindowsPath while the
        # checkpoint is loaded (presumably the checkpoint pickles PosixPath
        # objects that cannot be instantiated on Windows -- confirm).
        patch_posix_path = os.name == 'nt'
        if patch_posix_path:
            original_posix_path = pathlib.PosixPath
            pathlib.PosixPath = pathlib.WindowsPath
        try:
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # checkpoints from a trusted source.
            model_data = torch.load(motion_model_path, map_location="cpu")
            model_args = NullableArgs(model_data['args'])
            model = DitTalkingHead(motion_feat_dim=model_args.motion_feat_dim,
                                   n_motions=model_args.n_motions,
                                   n_prev_motions=model_args.n_prev_motions,
                                   feature_dim=model_args.feature_dim,
                                   audio_model=model_args.audio_model,
                                   n_diff_steps=model_args.n_diff_steps,
                                   audio_encoder_path=audio_model_path)
            # Drop this entry before the non-strict load; tolerate checkpoints
            # that do not contain it instead of failing with KeyError.
            model_data['model'].pop('denoising_net.TE.pe', None)
            model.load_state_dict(model_data['model'], strict=False)
            model.to(self.device, dtype=self.dtype)
            model.eval()
        finally:
            # Always undo the global monkey patch, even when loading fails.
            if patch_posix_path:
                pathlib.PosixPath = original_posix_path

        self.motion_generator = model
        self.n_motions = model_args.n_motions
        self.n_prev_motions = model_args.n_prev_motions
        self.fps = model_args.fps
        self.audio_unit = self._SAMPLE_RATE / self.fps  # num of samples per frame
        self.n_audio_samples = round(self.audio_unit * self.n_motions)
        self.pad_mode = model_args.pad_mode
        self.use_indicator = model_args.use_indicator

        self.cfg_mode = kwargs.get("cfg_mode", "incremental")
        self.cfg_cond = kwargs.get("cfg_cond", None)
        self.cfg_scale = kwargs.get("cfg_scale", 2.8)

        # NOTE(review): pickle.load executes arbitrary code from the file; the
        # template must come from a trusted source.
        # The attribute name keeps its historical spelling ("templete") so
        # that existing external readers of this attribute keep working.
        with open(motion_template_path, 'rb') as fin:
            self.templete_dict = pickle.load(fin)

    def gen_motion_sequence(self, audio_path, **kwargs):
        """
        Convert an audio file into a LivePortrait motion sequence.

        The audio is resampled to 16 kHz, averaged over dim 0, padded to a
        whole number of ``n_motions``-frame windows and sampled window by
        window; every window after the first is conditioned on the tail of
        the previous one.

        Args:
            audio_path: path handed to ``torchaudio.load``.
            **kwargs: accepted for interface compatibility; not used.

        Returns:
            dict with the keys ``'n_frames'``, ``'output_fps'``, ``'motion'``
            (one dict per frame with ``exp``, ``scale``, ``R``, ``t``,
            ``pitch``, ``yaw``, ``roll``), ``'c_eyes_lst'`` and
            ``'c_lip_lst'`` (both empty lists).

        Raises:
            ValueError: if padding is required and ``pad_mode`` is neither
                ``'zero'`` nor ``'replicate'``.
        """
        # preprocess audio
        audio, sample_rate = torchaudio.load(audio_path)
        if sample_rate != self._SAMPLE_RATE:
            audio = torchaudio.functional.resample(
                audio,
                orig_freq=sample_rate,
                new_freq=self._SAMPLE_RATE,
            )
        # Collapse dim 0 (the channel axis in torchaudio.load's default
        # layout) into a single 1-D signal.
        audio = audio.mean(0).to(self.device, dtype=self.dtype)

        # crop audio into n_subdivision windows according to n_motions
        clip_len = int(len(audio) / self._SAMPLE_RATE * self.fps)
        stride = self.n_motions
        if clip_len <= self.n_motions:
            n_subdivision = 1
        else:
            n_subdivision = math.ceil(clip_len / stride)

        # pad the tail so the audio fills the last window completely
        n_padding_audio_samples = self.n_audio_samples * n_subdivision - len(audio)
        n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit)
        if n_padding_audio_samples > 0:
            if self.pad_mode == 'zero':
                padding_value = 0
            elif self.pad_mode == 'replicate':
                padding_value = audio[-1]
            else:
                raise ValueError(f'Unknown pad mode: {self.pad_mode}')
            audio = F.pad(audio, (0, n_padding_audio_samples), value=padding_value)

        # generate motions window by window
        coef_list = []
        prev_motion_feat = prev_audio_feat = noise = None
        for i in range(n_subdivision):
            is_last = i == n_subdivision - 1
            start_idx = i * stride
            end_idx = start_idx + self.n_motions

            indicator = torch.ones((1, self.n_motions)).to(self.device) if self.use_indicator else None
            if indicator is not None and is_last and n_padding_frames > 0:
                # Zero the indicator over the padded frames of the last window.
                indicator[:, -n_padding_frames:] = 0

            audio_in = audio[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0)
            sample_kwargs = dict(indicator=indicator,
                                 cfg_mode=self.cfg_mode,
                                 cfg_cond=self.cfg_cond,
                                 cfg_scale=self.cfg_scale,
                                 dynamic_threshold=0)
            if i == 0:
                motion_feat, noise, prev_audio_feat = self.motion_generator.sample(audio_in, **sample_kwargs)
            else:
                motion_feat, noise, prev_audio_feat = self.motion_generator.sample(
                    audio_in,
                    prev_motion_feat.to(self.dtype),
                    prev_audio_feat.to(self.dtype),
                    noise.to(self.dtype),
                    **sample_kwargs)

            # Keep the tail of this window as conditioning for the next one.
            prev_motion_feat = motion_feat[:, -self.n_prev_motions:].clone()
            prev_audio_feat = prev_audio_feat[:, -self.n_prev_motions:]

            motion_coef = motion_feat
            if is_last and n_padding_frames > 0:
                motion_coef = motion_coef[:, :-n_padding_frames]  # delete padded frames
            coef_list.append(motion_coef)

        motion_coef = self._merge_motion_coefs(coef_list)
        n_frames = motion_coef.shape[0]

        motion_list = []
        for idx in tqdm(range(n_frames), total=n_frames):
            motion_list.append(self._denormalize_frame(motion_coef[idx]))

        tgt_motion = {'n_frames': n_frames, 'output_fps': self.fps, 'motion': motion_list, 'c_eyes_lst': [],
                      'c_lip_lst': []}
        return tgt_motion

    @staticmethod
    def _merge_motion_coefs(coef_list):
        """Concatenate per-window coefficients along dim 1 and return them as a float32 array without dim 0."""
        merged = torch.cat(coef_list, dim=1)
        # squeeze(0) instead of a bare squeeze(): the latter would also drop
        # the frame axis when exactly one frame is left, which breaks the
        # per-frame loop in the caller. The float32 cast happens on the tensor
        # side so dtypes NumPy cannot represent (e.g. bfloat16) still convert.
        return merged.squeeze(0).float().cpu().numpy()

    def _rescale(self, values, name):
        """Map ``values`` onto the template range ``[min_<name>, max_<name>]``."""
        low = self.templete_dict[f"min_{name}"]
        high = self.templete_dict[f"max_{name}"]
        return values * (high - low) + low

    def _denormalize_frame(self, frame):
        """Turn one row of model output into a LivePortrait motion dict using the template statistics."""
        # Coefficient layout used below: [0:63] expression, [63] scale,
        # [64:67] translation, [67] pitch, [68] yaw, [69] roll.
        exp = frame[:63] * self.templete_dict["std_exp"] + self.templete_dict["mean_exp"]
        scale = self._rescale(frame[63:64], "scale")
        t = self._rescale(frame[64:67], "t")
        pitch = self._rescale(frame[67:68], "pitch")
        yaw = self._rescale(frame[68:69], "yaw")
        roll = self._rescale(frame[69:70], "roll")

        R = utils.get_rotation_matrix(pitch, yaw, roll)
        R = R.reshape(1, 3, 3).astype(np.float32)
        exp = exp.reshape(1, 21, 3).astype(np.float32)
        scale = scale.reshape(1, 1).astype(np.float32)
        t = t.reshape(1, 3).astype(np.float32)
        pitch = pitch.reshape(1, 1).astype(np.float32)
        yaw = yaw.reshape(1, 1).astype(np.float32)
        roll = roll.reshape(1, 1).astype(np.float32)
        return {"exp": exp, "scale": scale, "R": R, "t": t, "pitch": pitch, "yaw": yaw, "roll": roll}