# TalkSHOW/scripts/continuity.py
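# Inference script for the "continuity" setting: for each test clip it runs the
# body generator on the clip's audio with continuity enabled, pairs the result
# with zeroed face parameters, converts everything to SMPL-X vertices, and
# renders the sequence with RenderTool._render_continuity.
# (Summary inferred from the code below.)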
import os
import sys

# Put the repo root on the import path before loading project-local packages.
sys.path.append(os.getcwd())

from transformers import Wav2Vec2Processor
from visualise.rendering import RenderTool
from glob import glob
import numpy as np
import json
import smplx as smpl
from nets import *
from trainer.options import parse_args
from data_utils import torch_data
from trainer.config import load_JsonConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from scripts.diversity import init_model, init_dataloader, get_vertices
from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
import time
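# Default global orientation (axis-angle). The same rotation is written into
# body_static inside infer() and can be used to override the predicted root
# orientation (see the commented-out assignment further down).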
global_orient = torch.tensor([3.0747, -0.0158, -0.0152])
def infer(data_root, g_body, g_face, g_body2, exp_name, infer_loader, infer_set, device, norm_stats, smplx,
smplx_model, rendertool, args=None, config=None, var=None):
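    # Run audio-to-motion inference for every clip in infer_loader and render the
    # result. g_face and g_body2 are passed in but not used here (the face branch
    # is commented out below).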
am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
am_sr = 16000
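    # The wav2vec2 phoneme processor and its 16 kHz sample rate above are only
    # consumed by the face generator call, which is commented out further down.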
num_sample = 1
face = False
if face:
body_static = torch.zeros([1, 162], device='cuda')
body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
stand = False
j = 0
gt_0 = None
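    # Iterate over test clips; each batch provides poses, (optionally) expression
    # coefficients, the speaker label, the audio file path, and the SMPL-X betas.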
for bat in infer_loader:
poses_ = bat['poses'].to(torch.float32).to(device)
if poses_.shape[-1] == 300:
j = j + 1
if j > 1000:
continue
id = bat['speaker'].to('cuda') - 20
if config.Data.pose.expression:
expression = bat['expression'].to(device).to(torch.float32)
poses = torch.cat([poses_, expression], dim=1)
else:
poses = poses_
cur_wav_file = bat['aud_file'][0]
betas = bat['betas'][0].to(torch.float64).to('cuda')
# betas = torch.zeros([1, 300], dtype=torch.float64).to('cuda')
gt = poses.to('cuda').squeeze().transpose(1, 0)
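        # Undo the training-time normalization and, if poses were stored as 6D
        # rotations, convert them back to axis-angle before re-attaching the
        # 100 expression coefficients.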
if config.Data.pose.normalization:
gt = denormalize(gt, norm_stats[0], norm_stats[1]).squeeze(dim=0)
if config.Data.pose.convert_to_6d:
if config.Data.pose.expression:
gt_exp = gt[:, -100:]
gt = gt[:, :-100]
gt = gt.reshape(gt.shape[0], -1, 6)
gt = matrix_to_axis_angle(rotation_6d_to_matrix(gt)).reshape(gt.shape[0], -1)
gt = torch.cat([gt, gt_exp], -1)
if face:
gt = torch.cat([gt[:, :3], body_static.repeat(gt.shape[0], 1), gt[:, -100:]], dim=-1)
result_list = [gt]
# cur_wav_file = '.\\training_data\\french-V4.wav'
# pred_face = g_face.infer_on_audio(cur_wav_file,
# initial_pose=poses_,
# norm_stats=None,
# w_pre=False,
# # id=id,
# frame=None,
# am=am,
# am_sr=am_sr
# )
#
# pred_face = torch.tensor(pred_face).squeeze().to('cuda')
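        # The face generator is disabled in this script, so the jaw (3) and
        # expression (100) channels are filled with zeros instead of a prediction.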
pred_face = torch.zeros([gt.shape[0], 103], device='cuda')
pred_jaw = pred_face[:, :3]
pred_face = pred_face[:, 3:]
# id = torch.tensor([0], device='cuda')
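        # Sample num_sample body sequences from the audio. continuity=True is the
        # option this script exercises (paired with rendertool._render_continuity
        # at the end of the loop).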
for i in range(num_sample):
pred_res = g_body.infer_on_audio(cur_wav_file,
initial_pose=poses_,
norm_stats=norm_stats,
txgfile=None,
id=id,
var=var,
fps=30,
continuity=True,
smooth=False
)
pred = torch.tensor(pred_res).squeeze().to('cuda')
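            # Match the body sequence length to the face sequence: repeat the last
            # frame if the prediction is short, otherwise truncate it.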
if pred.shape[0] < pred_face.shape[0]:
repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
pred = torch.cat([pred, repeat_frame], dim=0)
else:
pred = pred[:pred_face.shape[0], :]
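            # Convert 6D-rotation output back to axis-angle if needed, then assemble
            # the final vector as [jaw | body pose | expression].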
if config.Data.pose.convert_to_6d:
pred = pred.reshape(pred.shape[0], -1, 6)
pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
pred = pred.reshape(pred.shape[0], -1)
pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)
# pred[:, 9:12] = global_orient
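            # part2full presumably completes the predicted (partial) pose into a
            # full-body SMPL-X pose; `stand` is forwarded to control the lower body.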
pred = part2full(pred, stand)
if face:
pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
# result_list[0] = poses2pred(result_list[0], stand)
# if gt_0 is None:
# gt_0 = gt
# pred = pred2poses(pred, gt_0)
# result_list[0] = poses2poses(result_list[0], gt_0)
result_list.append(pred)
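        # Convert ground truth and predictions to SMPL-X vertices, save the stacked
        # predictions as a .npy next to the rendered video, and render the result.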
vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)
result_list = [res.to('cpu') for res in result_list]
        pred_arr = np.concatenate(result_list[1:], axis=0)
        file_name = 'visualise/video/' + config.Log.name + '/' + \
                    os.path.splitext(os.path.basename(cur_wav_file.replace('\\', '/')))[0]
        np.save(file_name, pred_arr)
rendertool._render_continuity(cur_wav_file, vertices_list[1], frame=60)
def main():
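    # Parse CLI arguments, load the JSON config, build the body generator from its
    # checkpoint, set up the dataloader, SMPL-X model and renderer, then run infer().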
parser = parse_args()
args = parser.parse_args()
device = torch.device(args.gpu)
torch.cuda.set_device(device)
config = load_JsonConfig(args.config_file)
smplx = True
os.environ['smplx_npz_path'] = config.smplx_npz_path
os.environ['extra_joint_path'] = config.extra_joint_path
os.environ['j14_regressor_path'] = config.j14_regressor_path
print('init model...')
body_model_name = 's2g_body_pixel'
body_model_path = './experiments/2022-12-31-smplx_S2G-body-pixel-conti-wide/ckpt-99.pth' # './experiments/2022-10-09-smplx_S2G-body-pixel-aud-3p/ckpt-99.pth'
generator = init_model(body_model_name, body_model_path, args, config)
# face_model_name = 's2g_face'
# face_model_path = './experiments/2022-10-15-smplx_S2G-face-sgd-3p-wv2/ckpt-99.pth' # './experiments/2022-09-28-smplx_S2G-face-faceformer-3d/ckpt-99.pth'
# generator_face = init_model(face_model_name, face_model_path, args, config)
generator_face = None
print('init dataloader...')
infer_set, infer_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
    print('init smplx model...')
dtype = torch.float64
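    # NOTE: model_path below is a hardcoded local directory; point it at your own
    # SMPL-X model folder.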
model_params = dict(model_path='E:/PycharmProjects/Motion-Projects/models',
model_type='smplx',
create_global_orient=True,
create_body_pose=True,
create_betas=True,
num_betas=300,
create_left_hand_pose=True,
create_right_hand_pose=True,
use_pca=False,
flat_hand_mean=False,
create_expression=True,
num_expression_coeffs=100,
num_pca_comps=12,
create_jaw_pose=True,
create_leye_pose=True,
create_reye_pose=True,
create_transl=False,
# gender='ne',
dtype=dtype, )
smplx_model = smpl.create(**model_params).to('cuda')
print('init rendertool...')
rendertool = RenderTool('visualise/video/' + config.Log.name)
infer(config.Data.data_root, generator, generator_face, None, args.exp_name, infer_loader, infer_set, device,
norm_stats, smplx, smplx_model, rendertool, args, config, (None, None))
if __name__ == '__main__':
main()
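
# Example invocation (a sketch; the exact flag names are defined in
# trainer/options.parse_args and the config path is an assumption, so adjust
# both to your checkout):
#   python scripts/continuity.py --config_file ./config/body_pixel.json \
#       --gpu 0 --speakers <speaker> --exp_name continuity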