"""Gradio demo for TalkSHOW: audio-driven SMPL-X face and body motion.

At import time this script installs/locates the bundled ``pyrender``, builds
the face and body generators from command-line arguments, creates the SMPL-X
model and the render tool. ``main()`` serves ``infer`` through Gradio.
"""
import gradio as gr
import functools
import os
import sys

# import OpenGL.GL as gl
# Select the headless EGL backend; PyOpenGL reads this when it is first imported,
# so it must be set before any rendering module is imported below.
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"
# Install the vendored pyrender copy and make it importable.
os.system('pip install /home/user/app/pyrender')
sys.path.append('/home/user/app/pyrender')
# os.system(r"apt-get install -y python-opengl libosmesa6")
sys.path.append(os.getcwd())
# Former build steps for the `mesh` package, kept for reference:
# os.system(r"cd mesh-master")
# os.system(r"tar -jxvf boost_1_79_0.tar.bz2")
# os.system(r"mv boost_1_79_0 boost")
# os.system(r"CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/home/user/app/boost")
# os.system(r"export LIBRARY_PATH=$LIBRARY_PATH:/home/user/app/boost/stage/lib")
# os.system(r"apt-get update")
# os.system(r"apt-get install sudo")
#
# os.system(r"apt-get install libboost-dev")
# # os.system(r"sudo apt-get install gcc")
# # os.system(r"sudo apt-get install g++")
# os.system(r"make -C ./mesh-master all")
# os.system(r"cd ..")
# os.system("pip install --no-deps --verbose --no-cache-dir /home/user/app/mesh-fix-MSVC_compilation")

from transformers import Wav2Vec2Processor
import numpy as np
import json
import smplx as smpl

from nets import *
from trainer.options import parse_args
from data_utils import torch_data
from trainer.config import load_JsonConfig

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
from visualise.rendering import RenderTool

is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")


def init_model(model_name, model_path, args, config):
    """Build a generator network by name and load its checkpoint.

    Args:
        model_name: one of 's2g_face', 's2g_body_vq', 's2g_body_pixel',
            's2g_LS3DCG'.
        model_path: checkpoint path passed to ``torch.load`` (tensors are
            mapped to CPU while loading).
        args: parsed command-line arguments forwarded to the constructor.
        config: loaded JSON config forwarded to the constructor.

    Returns:
        The constructed generator with the checkpoint weights loaded.

    Raises:
        NotImplementedError: if ``model_name`` is not one of the names above.
    """
    # An if/elif chain (rather than a dict) keeps each class name resolved only
    # when selected, so a model missing from `nets` fails only if requested.
    if model_name == 's2g_face':
        generator = s2g_face(args, config)
    elif model_name == 's2g_body_vq':
        generator = s2g_body_vq(args, config)
    elif model_name == 's2g_body_pixel':
        generator = s2g_body_pixel(args, config)
    elif model_name == 's2g_LS3DCG':
        generator = LS3DCG(args, config)
    else:
        raise NotImplementedError(f"unsupported model_name: {model_name!r}")

    model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    if model_name == 'smplx_S2G':
        # NOTE(review): unreachable as written -- 'smplx_S2G' is rejected above.
        generator.generator.load_state_dict(model_ckpt['generator']['generator'])
    elif 'generator' in model_ckpt:
        generator.load_state_dict(model_ckpt['generator'])
    else:
        # NOTE(review): this passes a dict keyed by 'generator' (unlike the
        # branch above, which passes the inner state dict); presumably the
        # generator classes' load_state_dict unwraps it -- confirm.
        generator.load_state_dict({'generator': model_ckpt})

    return generator


def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
    """Run SMPL-X frame by frame over each predicted pose sequence.

    Each frame (row) of a sequence is sliced as: jaw 0:3, left eye 3:6,
    right eye 6:9, global orientation 9:12, body 12:75, left hand 75:120,
    right hand 120:165, expression 165:265.

    Args:
        smplx_model: SMPL-X model, called once per frame.
        betas: shape coefficients passed to every call.
        result_list: iterable of 2-D tensors (frames x pose parameters).
        exp: if true use each frame's expression slice, otherwise zeros.
        require_pose: if true also return the per-sequence body poses.

    Returns:
        ``(vertices_list, poses_list)`` where ``vertices_list`` holds one numpy
        array of per-frame vertices per sequence, and ``poses_list`` holds one
        CPU tensor of concatenated ``output.body_pose`` per sequence, or is
        ``None`` when ``require_pose`` is false.
    """
    # Zero expression used when `exp` is false. smplx concatenates betas with
    # expression in its forward pass, so it must share betas' dtype and device
    # (a default CPU float32 tensor breaks a float64 / CUDA model).
    default_expression = torch.zeros([1, 100], dtype=betas.dtype, device=betas.device)

    vertices_list = []
    poses_list = []
    # Results are detached below anyway; skip autograd bookkeeping.
    with torch.no_grad():
        for seq in result_list:
            seq_vertices = []
            seq_poses = []
            for frame in seq:
                output = smplx_model(
                    betas=betas,
                    expression=frame[165:265].unsqueeze(0) if exp else default_expression,
                    jaw_pose=frame[0:3].unsqueeze(0),
                    leye_pose=frame[3:6].unsqueeze(0),
                    reye_pose=frame[6:9].unsqueeze(0),
                    global_orient=frame[9:12].unsqueeze(0),
                    body_pose=frame[12:75].unsqueeze(0),
                    left_hand_pose=frame[75:120].unsqueeze(0),
                    right_hand_pose=frame[120:165].unsqueeze(0),
                    return_verts=True,
                )
                seq_vertices.append(output.vertices.detach().cpu().numpy().squeeze())
                seq_poses.append(output.body_pose.detach().cpu())
            vertices_list.append(np.asarray(seq_vertices))
            poses_list.append(torch.cat(seq_poses, dim=0))

    return vertices_list, (poses_list if require_pose else None)


global_orient = torch.tensor([3.0747, -0.0158, -0.0152])

parser = parse_args()
args = parser.parse_args()
args.gpu = device

# Checkpoints and the SMPL-X model file are downloaded only when not running
# locally; RUN_MODE is hard-coded to "local", so this is skipped as written.
RUN_MODE = "local"
if RUN_MODE != "local":
    os.system("wget -P experiments/2022-10-15-smplx_S2G-face-3d/ "
              "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/2022-10-15-smplx_S2G-face-3d/ckpt-99.pth")
    os.system("wget -P experiments/2022-10-31-smplx_S2G-body-vq-3d/ "
              "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/2022-10-31-smplx_S2G-body-vq-3d/ckpt-99.pth")
    os.system("wget -P experiments/2022-11-02-smplx_S2G-body-pixel-3d/ "
              "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/2022-11-02-smplx_S2G-body-pixel-3d/ckpt-99.pth")
    os.system("wget -P visualise/smplx/ "
              "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/smplx/SMPLX_NEUTRAL.npz")

config = load_JsonConfig("config/body_pixel.json")

face_model_name = args.face_model_name
face_model_path = args.face_model_path
body_model_name = args.body_model_name
body_model_path = args.body_model_path
smplx_path = './visualise/'

os.environ['smplx_npz_path'] = config.smplx_npz_path
os.environ['extra_joint_path'] = config.extra_joint_path
os.environ['j14_regressor_path'] = config.j14_regressor_path

print('init model...')
g_body = init_model(body_model_name, body_model_path, args, config)
generator2 = None
g_face = init_model(face_model_name, face_model_path, args, config)

print('init smlpx model...')
dtype = torch.float64
model_params = dict(model_path=smplx_path,
                    model_type='smplx',
                    create_global_orient=True,
                    create_body_pose=True,
                    create_betas=True,
                    num_betas=300,
                    create_left_hand_pose=True,
                    create_right_hand_pose=True,
                    use_pca=False,
                    flat_hand_mean=False,
                    create_expression=True,
                    num_expression_coeffs=100,
                    num_pca_comps=12,
                    create_jaw_pose=True,
                    create_leye_pose=True,
                    create_reye_pose=True,
                    create_transl=False,
                    # gender='ne',
                    dtype=dtype,
                    )
smplx_model = smpl.create(**model_params).to(device)

print('init rendertool...')
rendertool = RenderTool('visualise/video/' + config.Log.name)

# Speaker identities selectable in the UI, mapped to the speaker index passed
# to the body generator.
_IDENTITY_IDS = {'Oliver': 0, 'Chemistry': 1, 'Seth': 2, 'Conan': 3}


@functools.lru_cache(maxsize=1)
def _load_audio_processor():
    """Load the wav2vec2 phoneme processor once and reuse it across requests."""
    return Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")


def infer(wav, identity, pose):
    """Generate face/body motion for an audio clip and render it.

    Args:
        wav: audio value from the Gradio "audio" input; forwarded unchanged to
            both generators and the renderer.
        identity: speaker name, one of the keys of ``_IDENTITY_IDS``.
        pose: 'Stand' or 'Sit' for full body; anything else (the UI offers
            'Only Face') renders the face on a fixed static body.

    Returns:
        The string "result.mp4". NOTE(review): presumably the file written by
        ``rendertool._render_sequences`` -- confirm the output path.

    Raises:
        ValueError: if ``identity`` is not a known speaker (e.g. nothing was
            selected in the UI).
    """
    # Validate before any model work so a missing selection fails with a clear
    # message instead of an opaque error from torch.tensor later on.
    if identity not in _IDENTITY_IDS:
        raise ValueError(f"unknown identity {identity!r}; expected one of {list(_IDENTITY_IDS)}")
    speaker_id = torch.tensor([_IDENTITY_IDS[identity]], device=device)

    betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
    am = _load_audio_processor()
    am_sr = 16000
    num_sample = args.num_sample
    cur_wav_file = wav

    stand = pose == 'Stand'
    face = pose not in ('Stand', 'Sit')

    if face:
        # Static 162-column body whose columns 6:9 hold the fixed global orientation.
        body_static = torch.zeros([1, 162], device=device)
        body_static[:, 6:9] = global_orient.reshape(1, 3)

    result_list = []

    pred_face = g_face.infer_on_audio(cur_wav_file,
                                      initial_pose=None,
                                      norm_stats=None,
                                      w_pre=False,
                                      frame=None,
                                      am=am,
                                      am_sr=am_sr,
                                      )
    pred_face = torch.tensor(pred_face).squeeze().to(device)

    # Split the face prediction into the jaw rotation (axis-angle) and the rest.
    if config.Data.pose.convert_to_6d:
        pred_jaw = pred_face[:, :6].reshape(pred_face.shape[0], -1, 6)
        pred_jaw = matrix_to_axis_angle(rotation_6d_to_matrix(pred_jaw)).reshape(pred_face.shape[0], -1)
        pred_face = pred_face[:, 6:]
    else:
        pred_jaw = pred_face[:, :3]
        pred_face = pred_face[:, 3:]

    num_frames = pred_face.shape[0]
    for _ in range(num_sample):
        pred_res = g_body.infer_on_audio(cur_wav_file,
                                         initial_pose=None,
                                         norm_stats=None,
                                         txgfile=None,
                                         id=speaker_id,
                                         var=None,
                                         fps=30,
                                         w_pre=False,
                                         )
        pred = torch.tensor(pred_res).squeeze().to(device)

        # Match the body sequence length to the face: pad by repeating the last
        # frame, or truncate.
        if pred.shape[0] < num_frames:
            repeat_frame = pred[-1].unsqueeze(dim=0).repeat(num_frames - pred.shape[0], 1)
            pred = torch.cat([pred, repeat_frame], dim=0)
        else:
            pred = pred[:num_frames, :]

        if config.Data.pose.convert_to_6d:
            pred = pred.reshape(pred.shape[0], -1, 6)
            pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
            pred = pred.reshape(pred.shape[0], -1)

        if config.Model.model_name == 's2g_LS3DCG':
            pred = torch.cat([pred[:, :3], pred[:, 103:], pred[:, 3:103]], dim=-1)
        else:
            pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)

        pred = part2full(pred, stand)
        if face:
            # Keep the first 3 columns (jaw) and last 100 (expression); swap the
            # body for the static pose.
            pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)

        result_list.append(pred)

    vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)

    rendertool._render_sequences(cur_wav_file, vertices_list, stand=stand, face=face, whole_body=args.whole_body)
    return "result.mp4"


def main():
    """Build and launch the Gradio interface around ``infer``."""
    iface = gr.Interface(
        fn=infer,
        # NOTE(review): the value type produced by the "audio" shortcut depends
        # on the Gradio version; infer forwards it unchanged as `cur_wav_file`
        # -- confirm the generators/renderer accept it.
        inputs=[
            "audio",
            gr.Radio(list(_IDENTITY_IDS)),
            gr.Radio(["Stand", "Sit", "Only Face"]),
        ],
        outputs="video",
        examples=[[os.path.join(os.path.dirname(__file__), "demo_audio/style.wav"), "Oliver", "Sit"]],
    )
    iface.launch(debug=True)


if __name__ == '__main__':
    main()