# TalkSHOW / app.py
import gradio as gr
import os
import sys

# Render offscreen via EGL (the server has no display) and force a GL version
# high enough for pyrender's shaders.
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"

# Install the bundled pyrender fork at startup and make it, plus the repo
# root, importable.
os.system('pip install /home/user/app/pyrender')
sys.path.append('/home/user/app/pyrender')
sys.path.append(os.getcwd())
# os.system(r"cd mesh-master")
# os.system(r"tar -jxvf boost_1_79_0.tar.bz2")
# os.system(r"mv boost_1_79_0 boost")
# os.system(r"CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/home/user/app/boost")
# os.system(r"export LIBRARY_PATH=$LIBRARY_PATH:/home/user/app/boost/stage/lib")
# os.system(r"apt-get update")
# os.system(r"apt-get install sudo")
#
# os.system(r"apt-get install libboost-dev")
# # os.system(r"sudo apt-get install gcc")
# # os.system(r"sudo apt-get install g++")
# os.system(r"make -C ./mesh-master all")
# os.system(r"cd ..")
# os.system("pip install --no-deps --verbose --no-cache-dir /home/user/app/mesh-fix-MSVC_compilation")
import numpy as np
import torch
import smplx as smpl
from transformers import Wav2Vec2Processor

from nets import *  # provides s2g_face, s2g_body_vq, s2g_body_pixel, LS3DCG
from trainer.options import parse_args
from trainer.config import load_JsonConfig
from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
from data_utils.lower_body import part2full
from visualise.rendering import RenderTool
# Run on GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def init_model(model_name, model_path, args, config):
    """Build the requested generator and load its checkpoint (on CPU).

    Checkpoints may store weights under a 'generator' key or as a bare
    state_dict; both layouts are handled below.
    """
    if model_name == 's2g_face':
        generator = s2g_face(args, config)
    elif model_name == 's2g_body_vq':
        generator = s2g_body_vq(args, config)
    elif model_name == 's2g_body_pixel':
        generator = s2g_body_pixel(args, config)
    elif model_name == 's2g_LS3DCG':
        generator = LS3DCG(args, config)
    else:
        raise NotImplementedError(model_name)

    model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    if model_name == 'smplx_S2G':
        generator.generator.load_state_dict(model_ckpt['generator']['generator'])
    elif 'generator' in list(model_ckpt.keys()):
        generator.load_state_dict(model_ckpt['generator'])
    else:
        model_ckpt = {'generator': model_ckpt}
        generator.load_state_dict(model_ckpt)
    return generator
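
# g_face and g_body below are built with this helper; the model names and
# checkpoint paths come from parse_args() (see trainer.options for defaults).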
def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
    """Run SMPL-X frame by frame and collect the mesh vertices.

    Each frame in `result_list` is a 265-dim vector laid out as:
    jaw (3) | leye (3) | reye (3) | global_orient (3) | body (63) |
    left hand (45) | right hand (45) | expression (100).
    """
    vertices_list = []
    poses_list = []
    expression = torch.zeros([1, 100])
    for i in result_list:
        vertices = []
        poses = []
        for j in range(i.shape[0]):
            output = smplx_model(betas=betas,
                                 expression=i[j][165:265].unsqueeze(dim=0) if exp else expression,
                                 jaw_pose=i[j][0:3].unsqueeze(dim=0),
                                 leye_pose=i[j][3:6].unsqueeze(dim=0),
                                 reye_pose=i[j][6:9].unsqueeze(dim=0),
                                 global_orient=i[j][9:12].unsqueeze(dim=0),
                                 body_pose=i[j][12:75].unsqueeze(dim=0),
                                 left_hand_pose=i[j][75:120].unsqueeze(dim=0),
                                 right_hand_pose=i[j][120:165].unsqueeze(dim=0),
                                 return_verts=True)
            vertices.append(output.vertices.detach().cpu().numpy().squeeze())
            poses.append(output.body_pose.detach().cpu())
        vertices_list.append(np.asarray(vertices))
        poses_list.append(torch.cat(poses, dim=0))
    if require_pose:
        return vertices_list, poses_list
    return vertices_list, None
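
# Shape note (hypothetical 300-frame clip): get_vertices(smplx_model, betas,
# [poses], exp=True) with poses of shape (300, 265) returns
# ([array of shape (300, 10475, 3)], None), since the SMPL-X mesh has
# 10475 vertices.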
parser = parse_args()
args = parser.parse_args()
args.gpu = device

# Set RUN_MODE to anything other than "local" to download the checkpoints and
# the SMPL-X model from the Hugging Face Hub at startup.
RUN_MODE = "local"
if RUN_MODE != "local":
    os.system("wget -P experiments/2022-10-15-smplx_S2G-face-3d/ "
              "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/2022-10-15-smplx_S2G-face-3d/ckpt-99.pth")
    os.system("wget -P experiments/2022-10-31-smplx_S2G-body-vq-3d/ "
              "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/2022-10-31-smplx_S2G-body-vq-3d/ckpt-99.pth")
    os.system("wget -P experiments/2022-11-02-smplx_S2G-body-pixel-3d/ "
              "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/2022-11-02-smplx_S2G-body-pixel-3d/ckpt-99.pth")
    os.system("wget -P visualise/smplx/ "
              "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/smplx/SMPLX_NEUTRAL.npz")
config = load_JsonConfig("config/body_pixel.json")

face_model_name = args.face_model_name
face_model_path = args.face_model_path
body_model_name = args.body_model_name
body_model_path = args.body_model_path
smplx_path = './visualise/'

os.environ['smplx_npz_path'] = config.smplx_npz_path
os.environ['extra_joint_path'] = config.extra_joint_path
os.environ['j14_regressor_path'] = config.j14_regressor_path

print('init model...')
g_body = init_model(body_model_name, body_model_path, args, config)
g_face = init_model(face_model_name, face_model_path, args, config)

print('init smplx model...')
dtype = torch.float64
model_params = dict(model_path=smplx_path,
                    model_type='smplx',
                    create_global_orient=True,
                    create_body_pose=True,
                    create_betas=True,
                    num_betas=300,
                    create_left_hand_pose=True,
                    create_right_hand_pose=True,
                    use_pca=False,
                    flat_hand_mean=False,
                    create_expression=True,
                    num_expression_coeffs=100,
                    num_pca_comps=12,
                    create_jaw_pose=True,
                    create_leye_pose=True,
                    create_reye_pose=True,
                    create_transl=False,
                    dtype=dtype)
smplx_model = smpl.create(**model_params).to(device)
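
# num_betas=300 and num_expression_coeffs=100 match the 300-dim betas and
# 100-dim expression slices used in infer() and get_vertices(); with
# use_pca=False the hands take full 45-dim axis-angle poses.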
print('init rendertool...')
rendertool = RenderTool('visualise/video/' + config.Log.name)
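
# Inference pipeline: infer() predicts jaw + expression with g_face, samples
# body and hand motion with g_body, merges both streams into full SMPL-X pose
# sequences, converts them to meshes, and renders the result to result.mp4.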
def infer(wav, identity, pose):
    betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
    # Phoneme-level Wav2Vec2 processor used as the audio front end for the
    # face model.
    am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
    am_sr = 16000
    num_sample = args.num_sample
    cur_wav_file = wav

    # "Stand" and "Sit" animate the whole body; any other choice ("Only Face")
    # animates the face on a static body.
    if pose == 'Stand':
        stand = True
        face = False
    elif pose == 'Sit':
        stand = False
        face = False
    else:
        stand = False
        face = True

    if face:
        # Static 162-dim body block (leye | reye | global_orient | body | hands);
        # indices 6:9 hold the fixed global orientation.
        body_static = torch.zeros([1, 162], device=device)
        body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)

    speaker_ids = {'Oliver': 0, 'Chemistry': 1, 'Seth': 2, 'Conan': 3}
    id = speaker_ids[identity]

    result_list = []

    # Face branch: predict jaw pose + expression coefficients from audio.
    pred_face = g_face.infer_on_audio(cur_wav_file,
                                      initial_pose=None,
                                      norm_stats=None,
                                      w_pre=False,
                                      frame=None,
                                      am=am,
                                      am_sr=am_sr)
    pred_face = torch.tensor(pred_face).squeeze().to(device)

    if config.Data.pose.convert_to_6d:
        # The jaw rotation comes in the 6D representation; convert to axis-angle.
        pred_jaw = pred_face[:, :6].reshape(pred_face.shape[0], -1, 6)
        pred_jaw = matrix_to_axis_angle(rotation_6d_to_matrix(pred_jaw)).reshape(pred_face.shape[0], -1)
        pred_face = pred_face[:, 6:]
    else:
        pred_jaw = pred_face[:, :3]
        pred_face = pred_face[:, 3:]

    id = torch.tensor([id], device=device)

    for i in range(num_sample):
        # Body branch: sample body + hand motion conditioned on audio and
        # speaker identity.
        pred_res = g_body.infer_on_audio(cur_wav_file,
                                         initial_pose=None,
                                         norm_stats=None,
                                         txgfile=None,
                                         id=id,
                                         var=None,
                                         fps=30,
                                         w_pre=False)
        pred = torch.tensor(pred_res).squeeze().to(device)

        # Match the body sequence length to the face sequence: pad by
        # repeating the last frame, or truncate.
        if pred.shape[0] < pred_face.shape[0]:
            repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
            pred = torch.cat([pred, repeat_frame], dim=0)
        else:
            pred = pred[:pred_face.shape[0], :]

        if config.Data.pose.convert_to_6d:
            pred = pred.reshape(pred.shape[0], -1, 6)
            pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
            pred = pred.reshape(pred.shape[0], -1)

        if config.Model.model_name == 's2g_LS3DCG':
            # LS3DCG predicts face and body jointly; reorder to
            # jaw | body | expression.
            pred = torch.cat([pred[:, :3], pred[:, 103:], pred[:, 3:103]], dim=-1)
        else:
            pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)

        # Complete the partial-body prediction to a full SMPL-X pose sequence.
        pred = part2full(pred, stand)
        if face:
            # Keep the body static; retain only the predicted jaw + expression.
            pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
        result_list.append(pred)

    vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)

    rendertool._render_sequences(cur_wav_file, vertices_list, stand=stand, face=face, whole_body=args.whole_body)
    return "result.mp4"
def main():
    iface = gr.Interface(fn=infer,
                         inputs=["audio",
                                 gr.Radio(["Oliver", "Chemistry", "Seth", "Conan"]),
                                 gr.Radio(["Stand", "Sit", "Only Face"]),
                                 ],
                         outputs="video",
                         examples=[[os.path.join(os.path.dirname(__file__), "demo_audio/style.wav"), "Oliver", "Sit"]])
    iface.launch(debug=True)


if __name__ == '__main__':
    main()