# SynTalker / app.py
# Author: robinwitch (commit 5f87b33, message "add")
import os
import signal
import time
import csv
import sys
import warnings
import random
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import numpy as np
import time
import pprint
from loguru import logger
import smplx
from torch.utils.tensorboard import SummaryWriter
import wandb
import matplotlib.pyplot as plt
from utils import config, logger_tools, other_tools_hf, metric, data_transfer, other_tools
from dataloaders import data_tools
from dataloaders.build_vocab import Vocab
from optimizers.optim_factory import create_optimizer
from optimizers.scheduler_factory import create_scheduler
from optimizers.loss_factory import get_loss_func
from dataloaders.data_tools import joints_list
from utils import rotation_conversions as rc
import soundfile as sf
import librosa
import subprocess
from transformers import pipeline
from diffusion.model_util import create_gaussian_diffusion
from diffusion.resample import create_named_schedule_sampler
from models.vq.model import RVQVAE
import spaces
import pickle
# Select the EGL platform for PyOpenGL (presumably for headless/offscreen
# rendering in the Space -- TODO confirm against the renderer).
os.environ['PYOPENGL_PLATFORM']='egl'
# Run the MFA (Montreal Forced Aligner) setup script once, at import time.
# The CompletedProcess is only printed; its return code is not checked, so a
# failed install only surfaces later when ./demo/run_mfa.sh is invoked.
command = ["bash","./demo/install_mfa1.sh"]
result = subprocess.run(command, capture_output=True, text=True)
print("debug0: ", result)
# command = ["bash","./demo/install_mfa.sh"]
# result = subprocess.run(command, capture_output=True, text=True)
# print("debug1: ", result)
# Preferred torch device string; not referenced elsewhere in this file (the
# ASR pipeline below is pinned to CPU explicitly).
device = "cuda" if torch.cuda.is_available() else "cpu"
# English Whisper-tiny speech recognizer used by run_pipeline() to transcribe
# the uploaded audio; chunk_length_s=30 enables chunked transcription of longer
# inputs (per the Transformers pipeline API).
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny.en",
    chunk_length_s=30,
    device='cpu',
)
# @spaces.GPU()
def run_pipeline(audio):
    """Transcribe ``audio`` with the module-level ASR ``pipe``.

    Args:
        audio: waveform accepted by the Transformers ASR pipeline.

    Returns:
        The transcript string (the ``"text"`` field of the pipeline output).
    """
    transcription = pipe(audio, batch_size=8)
    return transcription["text"]
# Module-level switch read by BaseTrainer.__init__: when True, the ASR
# transcription (tmp.lab) and the ./demo/run_mfa.sh alignment step are skipped.
debug = False
class BaseTrainer(object):
    """Single-request inference driver for the Gradio demo.

    ``__init__`` stores the uploaded audio in a timestamped working directory,
    transcribes and aligns it, builds the test dataloader and the diffusion
    model, and precomputes body-part masks and normalization statistics.
    ``test_demo`` runs generation and returns the Gradio output components.
    """

    def __init__(self, args, ap):
        """Prepare all per-request state.

        Args:
            args: parsed config namespace; mutated in place (``use_ddim``,
                ``textgrid_file_path``, ``audio_file_path``, ``tmp_dir`` and
                several ``vae_*`` fields).
            ap: ``(sample_rate, samples)`` pair as delivered by ``gr.Audio``.
        """
        # NOTE(review): this overrides the DDIM/DDPM choice made in syntalker(),
        # so the DDPM option is never used -- confirm whether that is intended
        # (e.g. to stay within the GPU time budget).
        args.use_ddim = True
        hf_dir = "hf"
        time_local = time.localtime()
        time_name_expend = "%02d%02d_%02d%02d%02d_" % (
            time_local[1], time_local[2], time_local[3], time_local[4], time_local[5]
        )
        self.time_name_expend = time_name_expend
        tmp_dir = args.out_path + "custom/" + time_name_expend + hf_dir
        os.makedirs(tmp_dir + "/", exist_ok=True)

        self.audio_path = tmp_dir + "/tmp.wav"
        # ap is (sample_rate, samples); soundfile.write takes (path, data, rate).
        sf.write(self.audio_path, ap[1], ap[0])
        # Re-read at the sample rate configured for the dataloader.
        audio, ssr = librosa.load(self.audio_path, sr=args.audio_sr)

        # use asr model to get corresponding text transcripts
        file_path = tmp_dir + "/tmp.lab"
        self.textgrid_path = tmp_dir + "/tmp.TextGrid"
        if not debug:
            text = run_pipeline(audio)
            with open(file_path, "w", encoding="utf-8") as file:
                file.write(text)
            # use montreal forced aligner to get textgrid
            command = ["bash", "./demo/run_mfa.sh", tmp_dir]
            result = subprocess.run(command, capture_output=True, text=True)
            print("debug2: ", result)

        self.args = args
        self.rank = 0  # single-process demo (was dist.get_rank())
        args.textgrid_file_path = self.textgrid_path
        args.audio_file_path = self.audio_path
        self.checkpoint_path = tmp_dir
        args.tmp_dir = tmp_dir

        self.test_data = __import__(
            f"dataloaders.{args.dataset}", fromlist=["something"]
        ).CustomDataset(args, "test")
        self.test_loader = torch.utils.data.DataLoader(
            self.test_data,
            batch_size=1,
            shuffle=False,
            num_workers=args.loader_workers,
            drop_last=False,
        )
        logger.info("Init test dataloader success")

        from models.denoiser import MDM
        self.model = MDM(args)
        if self.rank == 0:
            logger.info(self.model)
            logger.info(f"init {args.g_name} success")

        # 0/1 masks over the flattened axis-angle pose selecting each body part.
        self.ori_joint_list = joints_list[self.args.ori_joints]
        self.tar_joint_list_face = joints_list["beat_smplx_face"]
        self.tar_joint_list_upper = joints_list["beat_smplx_upper"]
        self.tar_joint_list_hands = joints_list["beat_smplx_hands"]
        self.tar_joint_list_lower = joints_list["beat_smplx_lower"]
        self.joints = 55
        self.joint_mask_face = self._build_joint_mask(self.tar_joint_list_face)
        self.joint_mask_upper = self._build_joint_mask(self.tar_joint_list_upper)
        self.joint_mask_hands = self._build_joint_mask(self.tar_joint_list_hands)
        self.joint_mask_lower = self._build_joint_mask(self.tar_joint_list_lower)

        self.tracker = other_tools.EpochTracker(
            ["fid", "l1div", "bc", "rec", "trans", "vel", "transv", 'dis', 'gen', 'acc', 'transa', 'exp', 'lvd', 'mse', "cls", "rec_face", "latent", "cls_full", "cls_self", "cls_word", "latent_word","latent_self","predict_x0_loss"],
            [False,True,True, False, False, False, False, False, False, False, False, False, False, False, False, False, False,False, False, False,False,False,False],
        )

        vq_model_module = __import__("models.motion_representation", fromlist=["something"])
        self.args.vae_layer = 2
        self.args.vae_length = 256
        self.args.vae_test_dim = 106
        vq_type = self.args.vqvae_type
        if vq_type == "vqvae":
            self.args.vae_layer = 4
            self.args.vae_test_dim = 78
            self.vq_model_upper = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank)
            other_tools.load_checkpoints(self.vq_model_upper, args.vqvae_upper_path, args.e_name)
            self.args.vae_test_dim = 180
            self.vq_model_hands = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank)
            other_tools.load_checkpoints(self.vq_model_hands, args.vqvae_hands_path, args.e_name)
            self.args.vae_test_dim = 54
            self.args.vae_layer = 4
            self.vq_model_lower = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank)
            other_tools.load_checkpoints(self.vq_model_lower, args.vqvae_lower_path, args.e_name)
        # Final VAE settings left on the shared args (the previous intermediate
        # vae_test_dim=61 assignment was immediately overwritten).
        # NOTE(review): confirm whether these belong inside the branch above.
        self.args.vae_test_dim = 330
        self.args.vae_layer = 4
        self.args.vae_length = 240

        self.cls_loss = nn.NLLLoss().to(self.rank)
        self.reclatent_loss = nn.MSELoss().to(self.rank)
        self.vel_loss = torch.nn.L1Loss(reduction='mean').to(self.rank)
        self.rec_loss = get_loss_func("GeodesicLoss").to(self.rank)
        self.log_softmax = nn.LogSoftmax(dim=2).to(self.rank)
        self.diffusion = create_gaussian_diffusion(use_ddim=args.use_ddim)
        self.schedule_sampler_type = 'uniform'
        self.schedule_sampler = create_named_schedule_sampler(self.schedule_sampler_type, self.diffusion)

        self.mean = np.load(args.mean_pose_path)
        self.std = np.load(args.std_pose_path)
        self.use_trans = args.use_trans
        # Always defined so test_demo can pass them on even without translation.
        self.trans_mean = None
        self.trans_std = None
        if self.use_trans:
            self.trans_mean = np.load(args.mean_trans_path)
            self.trans_std = np.load(args.std_trans_path)

        # Columns of the 6D-rotation pose vector belonging to each body part.
        upper_body_mask = self._rot6d_indices([3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21])
        hands_body_mask = self._rot6d_indices(range(25, 55))
        lower_body_mask = self._rot6d_indices([0, 1, 2, 4, 5, 7, 8, 10, 11])
        self.mean_upper = self.mean[upper_body_mask]
        self.mean_hands = self.mean[hands_body_mask]
        self.mean_lower = self.mean[lower_body_mask]
        self.std_upper = self.std[upper_body_mask]
        self.std_hands = self.std[hands_body_mask]
        self.std_lower = self.std[lower_body_mask]

    @staticmethod
    def _rot6d_indices(joint_ids):
        """Return the 6 consecutive 6D-rotation column indices of each joint id, in order."""
        return [joint * 6 + k for joint in joint_ids for k in range(6)]

    def _build_joint_mask(self, target_joints):
        """Return a 0/1 float mask over ``len(self.ori_joint_list) * 3`` columns.

        Each entry ``ori_joint_list[name]`` is read as ``(width, end)`` and marks
        the columns ``[end - width, end)`` of every joint in ``target_joints``.
        """
        mask = np.zeros(len(self.ori_joint_list) * 3)
        for joint_name in target_joints:
            entry = self.ori_joint_list[joint_name]
            mask[entry[1] - entry[0]:entry[1]] = 1
        return mask

    def inverse_selection(self, filtered_t, selection_array, n):
        """Scatter selected columns back into a full-width zero numpy array.

        Args:
            filtered_t: rows of selected components; rows ``0..n-1`` are used.
            selection_array: 1-D 0/1 mask; its ``size`` is the output width.
            n: number of rows to produce.

        Returns:
            ``np.ndarray`` of shape ``(n, selection_array.size)``.
        """
        original_shape_t = np.zeros((n, selection_array.size))
        selected_indices = np.where(selection_array == 1)[0]
        for i in range(n):
            original_shape_t[i, selected_indices] = filtered_t[i]
        return original_shape_t

    def inverse_selection_tensor(self, filtered_t, selection_array, n, device="cuda"):
        """Torch counterpart of ``inverse_selection``.

        The output width is ``selection_array.size`` (previously hard-coded to
        165). ``device`` selects where the float32 result is allocated and
        defaults to CUDA as before. A 1-D ``filtered_t`` is treated as one value
        per row, broadcast across the selected columns.
        """
        mask = torch.from_numpy(selection_array).to(device)
        selected_indices = torch.where(mask == 1)[0]
        original_shape_t = torch.zeros((n, selection_array.size), device=device)
        rows = filtered_t[:n]
        if rows.dim() == 1:
            rows = rows.unsqueeze(1)
        original_shape_t[:, selected_indices] = rows.to(device)
        return original_shape_t

    def test_demo(self, epoch):
        """Generate motion for the uploaded audio and build the Gradio outputs.

        No loss or metric is computed. Results go to
        ``<checkpoint_path>/<epoch>/result_<timestamp>.npz``, which is rendered
        to video when ``args.render_video`` is set.

        Returns:
            ``[gr.Video, gr.File]`` for the rendered video and the npz file.
        """
        results_save_path = self.checkpoint_path + f"/{epoch}/"
        if os.path.exists(results_save_path):
            import shutil
            shutil.rmtree(results_save_path)
        os.makedirs(results_save_path)
        start_time = time.time()
        total_length = 0
        # Defaults keep the outputs well-defined even if the loader yields nothing.
        render_vid_path = None
        results_npz_file_save_path = None
        with torch.no_grad():
            for batch_data in self.test_loader:
                try:
                    net_out = _warp(
                        self.args, self.model, batch_data, self.joints,
                        self.joint_mask_upper, self.joint_mask_hands, self.joint_mask_lower,
                        self.use_trans,
                        self.mean_upper, self.mean_hands, self.mean_lower,
                        self.std_upper, self.std_hands, self.std_lower,
                        self.trans_mean, self.trans_std,
                    )
                    print("debug8: return try")
                except Exception:
                    # Best-effort fallback kept from the original design: reuse the
                    # result _warp pickled to "tmp_file" on its last successful run.
                    # NOTE(review): that file is shared by every request in this
                    # working directory, so the result may be stale or another user's.
                    logger.exception("GPU inference failed; falling back to tmp_file")
                    print("debug9: return fail, use pickle load file")
                    with open("tmp_file", "rb") as tmp_file:
                        net_out = pickle.load(tmp_file)
                tar_pose = net_out['tar_pose']
                rec_pose = net_out['rec_pose']
                rec_trans = net_out['rec_trans']
                rec_exps = net_out['rec_exps']
                bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
                if (30 / self.args.pose_fps) != 1:
                    # NOTE(review): only rec_pose is upsampled to 30 fps; rec_trans and
                    # rec_exps keep their frame count, so the reshapes below would fail.
                    assert 30 % self.args.pose_fps == 0
                    n *= int(30 / self.args.pose_fps)
                    rec_pose = torch.nn.functional.interpolate(
                        rec_pose.permute(0, 2, 1), scale_factor=30 / self.args.pose_fps, mode='linear'
                    ).permute(0, 2, 1)
                rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs * n, j, 6))
                rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j * 6)
                rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs * n, j, 6))
                rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs * n, j * 3)
                rec_pose_np = rec_pose.numpy()
                rec_trans_np = rec_trans.numpy().reshape(bs * n, 3)
                rec_exp_np = rec_exps.numpy().reshape(bs * n, 100)
                # Body shape is taken from a bundled example sequence.
                gt_npz = np.load("./demo/examples/2_scott_0_1_1.npz", allow_pickle=True)
                results_npz_file_save_path = results_save_path + f"result_{self.time_name_expend[:-1]}" + '.npz'
                np.savez(
                    results_npz_file_save_path,
                    betas=gt_npz["betas"],
                    poses=rec_pose_np,
                    expressions=rec_exp_np,
                    trans=rec_trans_np,
                    model='smplx2020',
                    gender='neutral',
                    mocap_frame_rate=30,
                )
                total_length += n
                render_vid_path = None
                if self.args.render_video:
                    render_vid_path = other_tools_hf.render_one_sequence_no_gt(
                        results_npz_file_save_path,
                        results_save_path,
                        self.audio_path,
                        self.args.data_path_1 + "smplx_models/",
                        use_matplotlib=False,
                        args=self.args,
                    )
        result = [
            gr.Video(value=render_vid_path, visible=True),
            gr.File(value=results_npz_file_save_path, label="download motion and visualize in blender"),
        ]
        end_time = time.time() - start_time
        logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion")
        return result
@spaces.GPU(duration=60)
def _warp(
    args, model, batch_data, joints,
    joint_mask_upper, joint_mask_hands, joint_mask_lower,
    use_trans,
    mean_upper, mean_hands, mean_lower,
    std_upper, std_hands, std_lower,
    trans_mean, trans_std,
):
    """GPU-allocated entry point (``spaces.GPU``, duration=60): full inference.

    Builds the diffusion sampler, moves the models and statistics to CUDA,
    encodes ``batch_data`` and samples motion. The result is also pickled to
    ``tmp_file`` in the working directory; ``BaseTrainer.test_demo`` reads that
    file back when this call raises.
    """
    diffusion = create_gaussian_diffusion(use_ddim=args.use_ddim)
    cuda_state = _warp_create_cuda_model(
        args, model,
        mean_upper, mean_hands, mean_lower,
        std_upper, std_hands, std_lower,
        trans_mean, trans_std,
    )
    (
        args, model,
        vq_model_upper, vq_model_hands, vq_model_lower,
        mean_upper, mean_hands, mean_lower,
        std_upper, std_hands, std_lower,
        trans_mean, trans_std,
        vqvae_latent_scale,
    ) = cuda_state
    loaded_data = _warp_load_data(
        batch_data, joints,
        joint_mask_upper, joint_mask_hands, joint_mask_lower,
        args, use_trans,
        mean_upper, mean_hands, mean_lower,
        std_upper, std_hands, std_lower,
        trans_mean, trans_std,
        vq_model_upper, vq_model_hands, vq_model_lower,
    )
    net_out = _warp_g_test(
        loaded_data, diffusion, args, joints,
        joint_mask_upper, joint_mask_hands, joint_mask_lower,
        model, vqvae_latent_scale,
        vq_model_upper, vq_model_hands, vq_model_lower,
        use_trans, trans_std, trans_mean,
        std_upper, std_hands, std_lower,
        mean_upper, mean_hands, mean_lower,
    )
    with open("tmp_file", "wb") as cache_file:
        pickle.dump(net_out, cache_file)
    return net_out
def _warp_inverse_selection_tensor(filtered_t, selection_array, n):
selection_array = torch.from_numpy(selection_array).cuda()
original_shape_t = torch.zeros((n, 165)).cuda()
selected_indices = torch.where(selection_array == 1)[0]
for i in range(n):
original_shape_t[i, selected_indices] = filtered_t[i]
return original_shape_t
def _warp_g_test(
    loaded_data, diffusion, args, joints,
    joint_mask_upper, joint_mask_hands, joint_mask_lower,
    model, vqvae_latent_scale,
    vq_model_upper, vq_model_hands, vq_model_lower,
    use_trans, trans_std, trans_mean,
    std_upper, std_hands, std_lower,
    mean_upper, mean_hands, mean_lower,
):
    """Sample body-part latents window by window and decode them to poses.

    The sequence is generated in windows of ``args.pose_length`` frames. Each
    window after the first is seeded with the last ``args.pre_frames`` latent
    steps of the previous sample, and its word/audio context overlaps by
    ``args.pre_frames * args.vqvae_squeeze_scale`` frames. Upper-body, hand and
    lower-body latents are decoded by their VQ models, optionally
    de-normalized, scattered into a 55-joint axis-angle pose, and converted to
    6D rotations.

    The jaw pose and facial expressions are copied from the input, since this
    demo only generates body motion. When ``use_trans`` is False the
    translation is copied from the input as well; previously ``rec_trans`` was
    left undefined in that case, causing a NameError.

    Returns:
        dict of CPU tensors. ``rec_pose`` and ``tar_pose`` are
        ``(bs, n, joints*6)``; ``rec_trans``, ``tar_trans``, ``tar_exps``,
        ``rec_exps`` and ``tar_beta`` are trimmed to the generated length ``n``.

    Raises:
        ValueError: if fewer than ``args.pose_length`` frames remain after the
            input is trimmed to a multiple of 8.
    """
    sample_fn = diffusion.ddim_sample_loop if args.use_ddim else diffusion.p_sample_loop
    bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], joints
    tar_pose = loaded_data["tar_pose"]
    tar_beta = loaded_data["tar_beta"]
    tar_exps = loaded_data["tar_exps"]
    tar_trans = loaded_data["tar_trans"]
    in_word = loaded_data["in_word"]
    in_audio = loaded_data["in_audio"]
    in_seed = loaded_data["latent_in"]
    squeeze = args.vqvae_squeeze_scale

    # Drop trailing frames so the length is a multiple of 8, and trim the
    # latent sequence once by the matching number of latent steps (previously
    # it was sliced relative to an already-trimmed copy and trimmed twice).
    remain = n % 8
    if remain != 0:
        tar_pose = tar_pose[:, :-remain, :]
        tar_beta = tar_beta[:, :-remain, :]
        tar_trans = tar_trans[:, :-remain, :]
        in_word = in_word[:, :-remain]
        tar_exps = tar_exps[:, :-remain, :]
        in_seed = in_seed[:, :in_seed.shape[1] - (remain // squeeze), :]
        n = n - remain

    pre_len = args.pre_frames * squeeze        # overlap between windows, in frames
    round_l = args.pose_length - pre_len       # new frames contributed per window
    roundt = (n - pre_len) // round_l          # number of windows
    tail = (n - pre_len) % round_l             # frames left over after the last window
    print("debug3:finish it!")
    if roundt < 1:
        # Otherwise torch.cat([]) below would fail with an opaque error.
        raise ValueError(
            f"input too short: {n} frames after trimming, "
            f"need at least args.pose_length={args.pose_length}"
        )

    audio_per_frame = 16000 // 30  # assumes 16 kHz audio and 30 fps frames
    rec_all_upper, rec_all_hands, rec_all_lower = [], [], []
    last_sample = None
    for i in range(roundt):
        frame_start = i * round_l
        frame_end = (i + 1) * round_l
        in_word_tmp = in_word[:, frame_start:frame_end + pre_len]
        in_audio_tmp = in_audio[:, audio_per_frame * frame_start:audio_per_frame * (frame_end + pre_len)]
        in_id_tmp = loaded_data['tar_id'][:, frame_start:frame_end + args.pre_frames]
        if i == 0:
            in_seed_tmp = in_seed[:, :args.pre_frames, :]
        else:
            in_seed_tmp = last_sample[:, -args.pre_frames:, :]
        cond_ = {'y': {
            'audio': in_audio_tmp,
            'word': in_word_tmp,
            'id': in_id_tmp,
            'seed': in_seed_tmp,
            # NOTE(review): sized with args.batch_size rather than bs -- confirm intent.
            'mask': (torch.zeros([args.batch_size, 1, 1, args.pose_length]) < 1).cuda(),
            'style_feature': torch.zeros([bs, 512]).cuda(),
        }}
        shape_ = (bs, 1536, 1, 32)
        sample = sample_fn(
            model,
            shape_,
            clip_denoised=False,
            model_kwargs=cond_,
            skip_timesteps=0,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )
        # squeeze() also drops the batch axis, so this assumes bs == 1.
        sample = sample.squeeze().permute(1, 0).unsqueeze(0)
        last_sample = sample.clone()
        # Later windows drop their seed steps, which repeat the previous window.
        start = 0 if i == 0 else args.pre_frames
        rec_all_upper.append(sample[:, start:, :512])
        rec_all_hands.append(sample[:, start:, 512:1024])
        rec_all_lower.append(sample[:, start:, 1024:1536])
    print("debug4:finish it!")

    rec_all_upper = torch.cat(rec_all_upper, dim=1) * vqvae_latent_scale
    rec_all_hands = torch.cat(rec_all_hands, dim=1) * vqvae_latent_scale
    rec_all_lower = torch.cat(rec_all_lower, dim=1) * vqvae_latent_scale
    rec_upper = vq_model_upper.latent2origin(rec_all_upper)[0]
    rec_hands = vq_model_hands.latent2origin(rec_all_hands)[0]
    rec_lower = vq_model_lower.latent2origin(rec_all_lower)[0]

    rec_trans = None
    if use_trans:
        # The last 3 lower-body channels are normalized translation velocity:
        # integrate x/z over time and keep y as an absolute value.
        rec_trans_v = rec_lower[..., -3:] * trans_std + trans_mean
        rec_trans = torch.cumsum(rec_trans_v, dim=-2)
        rec_trans[..., 1] = rec_trans_v[..., 1]
        rec_lower = rec_lower[..., :-3]
    if args.pose_norm:
        rec_upper = rec_upper * std_upper + mean_upper
        rec_hands = rec_hands * std_hands + mean_hands
        rec_lower = rec_lower * std_lower + mean_lower

    n = n - tail
    tar_pose = tar_pose[:, :n, :]
    tar_exps = tar_exps[:, :n, :]
    tar_trans = tar_trans[:, :n, :]
    tar_beta = tar_beta[:, :n, :]
    rec_exps = tar_exps
    if rec_trans is None:
        rec_trans = tar_trans

    rec_pose_legs = rec_lower[:, :, :54]
    bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1]

    def _part_to_full(part_6d, num_joints, joint_mask):
        """6D rotations of one body part -> full-width (bs*n, 165) axis-angle rows."""
        matrices = rc.rotation_6d_to_matrix(part_6d.reshape(bs, n, num_joints, 6))
        axis_angle = rc.matrix_to_axis_angle(matrices).reshape(bs * n, num_joints * 3)
        return _warp_inverse_selection_tensor(axis_angle, joint_mask, bs * n)

    rec_pose = (
        _part_to_full(rec_upper, 13, joint_mask_upper)
        + _part_to_full(rec_pose_legs, 9, joint_mask_lower)
        + _part_to_full(rec_hands, 30, joint_mask_hands)
    )
    # Jaw (columns 66:69) comes from the input pose.
    rec_pose[:, 66:69] = tar_pose.reshape(bs * n, 55 * 3)[:, 66:69]
    rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs * n, j, 3))
    rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j * 6)
    tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs * n, j, 3))
    tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j * 6)
    print("debug5:finish it!")
    return {
        'rec_pose': rec_pose.detach().cpu(),
        'rec_trans': rec_trans.detach().cpu(),
        'tar_pose': tar_pose.detach().cpu(),
        'tar_exps': tar_exps.detach().cpu(),
        'tar_beta': tar_beta.detach().cpu(),
        'tar_trans': tar_trans.detach().cpu(),
        'rec_exps': rec_exps.detach().cpu(),
    }
def _warp_load_data(
    dict_data, joints,
    joint_mask_upper, joint_mask_hands, joint_mask_lower,
    args, use_trans,
    mean_upper, mean_hands, mean_lower,
    std_upper, std_hands, std_lower,
    trans_mean, trans_std,
    vq_model_upper, vq_model_hands, vq_model_lower,
):
    """Move one dataloader batch to CUDA and encode it into VQ latents.

    The first 165 pose channels are axis-angle (55 joints x 3); channels
    165:169 are foot contacts. Body parts are converted to 6D rotations,
    optionally normalized (``args.pose_norm``), optionally extended with
    normalized translation velocity (``use_trans``), and encoded by the three
    VQ models. ``latent_in`` is their concatenation divided by
    ``args.vqvae_latent_scale``.

    Raises:
        NotImplementedError: if ``args.use_motionclip`` is set. This app never
            loads a MotionCLIP encoder; that path used to fail with a NameError
            on ``motionclip``.
    """
    if args.use_motionclip:
        raise NotImplementedError(
            "args.use_motionclip is set, but this demo never loads a MotionCLIP encoder"
        )
    tar_pose_raw = dict_data["pose"]
    tar_pose = tar_pose_raw[:, :, :165].cuda()
    tar_contact = tar_pose_raw[:, :, 165:169].cuda()
    tar_trans = dict_data["trans"].cuda()
    tar_trans_v = dict_data["trans_v"].cuda()
    tar_exps = dict_data["facial"].cuda()
    in_audio = dict_data["audio"].cuda()
    in_word = dict_data["word"].cuda()
    tar_beta = dict_data["beta"].cuda()
    tar_id = dict_data["id"].cuda().long()
    bs, n = tar_pose.shape[0], tar_pose.shape[1]

    def _aa_to_6d(axis_angle, num_joints):
        """(bs, n, num_joints*3) axis-angle -> (bs, n, num_joints*6) 6D rotations."""
        matrices = rc.axis_angle_to_matrix(axis_angle.reshape(bs, n, num_joints, 3))
        return rc.matrix_to_rotation_6d(matrices).reshape(bs, n, num_joints * 6)

    tar_pose_jaw = _aa_to_6d(tar_pose[:, :, 66:69], 1)
    tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)
    tar_pose_hands = _aa_to_6d(tar_pose[:, :, 25 * 3:55 * 3], 30)
    tar_pose_upper = _aa_to_6d(tar_pose[:, :, joint_mask_upper.astype(bool)], 13)
    tar_pose_leg = _aa_to_6d(tar_pose[:, :, joint_mask_lower.astype(bool)], 9)
    tar_pose_lower = tar_pose_leg
    tar4dis = torch.cat([tar_pose_jaw, tar_pose_upper, tar_pose_hands, tar_pose_leg], dim=2)

    if args.pose_norm:
        tar_pose_upper = (tar_pose_upper - mean_upper) / std_upper
        tar_pose_hands = (tar_pose_hands - mean_hands) / std_hands
        tar_pose_lower = (tar_pose_lower - mean_lower) / std_lower
    if use_trans:
        tar_trans_v = (tar_trans_v - trans_mean) / trans_std
        tar_pose_lower = torch.cat([tar_pose_lower, tar_trans_v], dim=-1)

    latent_face_top = None  # no face VQ model in this demo
    latent_upper_top = vq_model_upper.map2latent(tar_pose_upper)
    latent_hands_top = vq_model_hands.map2latent(tar_pose_hands)
    latent_lower_top = vq_model_lower.map2latent(tar_pose_lower)
    latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2) / args.vqvae_latent_scale

    tar_pose_6d = _aa_to_6d(tar_pose, 55)
    latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)
    style_feature = None
    return {
        "tar_pose_jaw": tar_pose_jaw,
        "tar_pose_face": tar_pose_face,
        "tar_pose_upper": tar_pose_upper,
        "tar_pose_lower": tar_pose_lower,
        "tar_pose_hands": tar_pose_hands,
        'tar_pose_leg': tar_pose_leg,
        "in_audio": in_audio,
        "in_word": in_word,
        "tar_trans": tar_trans,
        "tar_exps": tar_exps,
        "tar_beta": tar_beta,
        "tar_pose": tar_pose,
        "tar4dis": tar4dis,
        "latent_face_top": latent_face_top,
        "latent_upper_top": latent_upper_top,
        "latent_hands_top": latent_hands_top,
        "latent_lower_top": latent_lower_top,
        "latent_in": latent_in,
        "tar_id": tar_id,
        "latent_all": latent_all,
        "tar_pose_6d": tar_pose_6d,
        "tar_contact": tar_contact,
        "style_feature": style_feature,
    }
def _warp_create_cuda_model(
    args, model,
    mean_upper, mean_hands, mean_lower,
    std_upper, std_hands, std_lower,
    trans_mean, trans_std,
):
    """Load the generator and the three RVQ-VAE tokenizers, then move everything to CUDA.

    Sets the residual-VQ hyper-parameters on ``args`` (mutated in place). When
    ``args.use_trans`` is set, it also switches ``args.vqvae_lower_path`` to the
    translation-aware checkpoint. Numpy statistics become CUDA tensors; a
    ``None`` statistic (e.g. translation stats when ``use_trans`` is off) is
    passed through as ``None`` instead of crashing ``torch.from_numpy``.

    Returns:
        ``(args, model, vq_model_upper, vq_model_hands, vq_model_lower,
        mean_upper, mean_hands, mean_lower, std_upper, std_hands, std_lower,
        trans_mean, trans_std, vqvae_latent_scale)``
    """
    other_tools.load_checkpoints(model, args.test_ckpt, args.g_name)
    # Residual-VQ hyper-parameters shared by the three body-part tokenizers.
    args.num_quantizers = 6
    args.shared_codebook = False
    args.quantize_dropout_prob = 0.2
    args.mu = 0.99
    args.nb_code = 512
    args.code_dim = 512
    args.down_t = 2
    args.stride_t = 2
    args.width = 512
    args.depth = 3
    args.dilation_growth_rate = 3
    args.vq_act = "relu"
    args.vq_norm = None

    def _build_rvqvae(dim_pose, body_part):
        """Construct one RVQVAE; args.body_part is set first because args is passed in."""
        args.body_part = body_part
        return RVQVAE(args,
                      dim_pose,
                      args.nb_code,
                      args.code_dim,
                      args.code_dim,
                      args.down_t,
                      args.stride_t,
                      args.width,
                      args.depth,
                      args.dilation_growth_rate,
                      args.vq_act,
                      args.vq_norm)

    vq_model_upper = _build_rvqvae(78, "upper")    # 13 joints x 6D
    vq_model_hands = _build_rvqvae(180, "hands")   # 30 joints x 6D
    lower_dim = 54                                  # 9 joints x 6D
    if args.use_trans:
        lower_dim = 57                              # + 3 translation-velocity channels
        args.vqvae_lower_path = args.vqvae_lower_trans_path
    vq_model_lower = _build_rvqvae(lower_dim, "lower")

    vq_models = (vq_model_upper, vq_model_hands, vq_model_lower)
    ckpt_paths = (args.vqvae_upper_path, args.vqvae_hands_path, args.vqvae_lower_path)
    for vq_model, ckpt_path in zip(vq_models, ckpt_paths):
        vq_model.load_state_dict(torch.load(ckpt_path)['net'])
    vqvae_latent_scale = args.vqvae_latent_scale
    for vq_model in vq_models:
        vq_model.eval().cuda()
    model = model.cuda()
    model.eval()

    def _to_cuda(array):
        return None if array is None else torch.from_numpy(array).cuda()

    return (
        args, model, vq_model_upper, vq_model_hands, vq_model_lower,
        _to_cuda(mean_upper), _to_cuda(mean_hands), _to_cuda(mean_lower),
        _to_cuda(std_upper), _to_cuda(std_hands), _to_cuda(std_lower),
        _to_cuda(trans_mean), _to_cuda(trans_std),
        vqvae_latent_scale,
    )
@logger.catch
def syntalker(audio_path,sample_stratege,render_video):
    """Gradio callback: run one SynTalker generation for the uploaded audio.

    Args:
        audio_path: value of the ``gr.Audio`` input, forwarded to BaseTrainer as ``ap``.
        sample_stratege: radio index; 1 selects DDPM, any other value DDIM.
            NOTE: BaseTrainer.__init__ currently resets ``args.use_ddim`` to True.
        render_video: radio index; 1 disables rendering, any other value enables it.

    Returns:
        The ``[gr.Video, gr.File]`` list produced by ``BaseTrainer.test_demo``.
    """
    args = config.parse_args()
    args.use_ddim = sample_stratege != 1
    args.render_video = render_video != 1
    print("sample_stratege",sample_stratege)
    print(sample_stratege)
    print(args.use_ddim)
    if not sys.warnoptions:
        warnings.simplefilter("ignore")
    other_tools_hf.set_random_seed(args)
    other_tools_hf.print_exp_info(args)
    return BaseTrainer(args, ap=audio_path).test_demo(999)
# Bundled example clips demo/examples/2_scott_0_{k}_{k}.wav for k = 1..5,
# one single-element row per example for the Gradio Interface.
examples = [[f"demo/examples/2_scott_0_{k}_{k}.wav"] for k in range(1, 6)]
# Gradio UI. Both radios use type="index", so syntalker() receives 0/1.
demo = gr.Interface(
    syntalker,  # function
    inputs=[
        gr.Audio(),
        # 0 for DDIM, 1 for DDPM.
        # NOTE(review): BaseTrainer.__init__ currently forces DDIM regardless.
        gr.Radio(choices=["DDIM", "DDPM"], label="Please select a sample strategy", type="index", value="DDIM"),
        # 0 renders the video, 1 skips rendering.
        gr.Radio(
            choices=["Yes", "No"],
            label="Please select whether to render the video or not; rendering takes an additional 10 minutes",
            type="index",
            value="Yes",
        ),
    ],
    outputs=[
        gr.Video(format="mp4", visible=True),
        gr.File(label="download motion and visualize in blender"),
    ],
    title='SynTalker: Enabling Synergistic Full-Body Control in Prompt-Based Co-Speech Motion Generation',
    description=(
        "1. Upload your audio. <br/>"
        "2. Then, sit back and wait for the rendering to happen! This may take a while (e.g. 2-12 minutes) <br/>"
        "(The running time is long because the provided GPU has a limited running-time quota, "
        "so some GPU tasks have to be handled on the CPU.)<br/>"
        "3. Afterwards, you can view the video. <br/>"
        "4. Note that we use a fixed face animation; our method only produces body motion. <br/>"
        "5. The DDPM sampling strategy generates better results but takes more inference time. "
    ),
    article=(
        "Project links: [SynTalker](https://robinwitch.github.io/SynTalker-Page). <br/>"
        "Reference links: [EMAGE](https://pantomatrix.github.io/EMAGE/). "
    ),
    examples=examples,
)
if __name__ == "__main__":
    # Rendezvous settings for torch.distributed-style code, set before launch;
    # this demo runs as a single process (BaseTrainer uses rank 0).
    os.environ.update({"MASTER_ADDR": '127.0.0.1', "MASTER_PORT": '8675'})
    demo.launch(share=True)