# TalkSHOW: nets/smplx_face.py
import os
import sys
sys.path.append(os.getcwd())
from nets.layers import *
from nets.base import TrainWrapperBaseClass
# from nets.spg.faceformer import Faceformer
from nets.spg.s2g_face import Generator as s2g_face
from losses import KeypointLoss
from nets.utils import denormalize
from data_utils import get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
from sklearn.preprocessing import normalize
import smplx
class TrainWrapper(TrainWrapperBaseClass):
'''
    A wrapper that receives a batch from data_utils and computes the loss.
'''
def __init__(self, args, config):
self.args = args
self.config = config
self.device = torch.device(self.args.gpu)
self.global_step = 0
self.convert_to_6d = self.config.Data.pose.convert_to_6d
self.expression = self.config.Data.pose.expression
self.epoch = 0
self.init_params()
self.num_classes = 4
self.generator = s2g_face(
n_poses=self.config.Data.pose.generate_length,
each_dim=self.each_dim,
dim_list=self.dim_list,
training=not self.args.infer,
device=self.device,
            identity=not self.convert_to_6d,
num_classes=self.num_classes,
).to(self.device)
# self.generator = Faceformer().to(self.device)
self.discriminator = None
self.am = None
self.MSELoss = KeypointLoss().to(self.device)
super().__init__(args, config)
def init_optimizer(self):
self.generator_optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, self.generator.parameters()),
lr=0.001,
momentum=0.9,
nesterov=False,
)
def init_params(self):
        # Each rotation uses 3 values (axis-angle), or 6 when converted to the
        # continuous 6D rotation representation.
        if self.convert_to_6d:
            scale = 2
        else:
            scale = 1
global_orient = round(3 * scale)
leye_pose = reye_pose = round(3 * scale)
jaw_pose = round(3 * scale)
body_pose = round(63 * scale)
left_hand_pose = right_hand_pose = round(45 * scale)
if self.expression:
expression = 100
else:
expression = 0
        # Packed pose vector layout: [jaw | eyes | body | hands | expression];
        # b_* are the start offsets of each block in the flattened vector.
        b_j = 0
        jaw_dim = jaw_pose
        b_e = b_j + jaw_dim
        eye_dim = leye_pose + reye_pose
        b_b = b_e + eye_dim
        body_dim = global_orient + body_pose
        b_h = b_b + body_dim
        hand_dim = left_hand_pose + right_hand_pose
        b_f = b_h + hand_dim
        face_dim = expression

        self.dim_list = [b_j, b_e, b_b, b_h, b_f]
        self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim + face_dim
        self.pose = int(self.full_dim / round(3 * scale))
        self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
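    # For reference, with scale == 1 (axis-angle) and expression enabled, the
    # blocks above work out to jaw = 3, eyes = 6, body = 66, hands = 90,
    # expression = 100, so full_dim = 265 and dim_list = [0, 3, 9, 75, 165].
    # With convert_to_6d (scale == 2) every rotation block doubles.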
def __call__(self, bat):
# assert (not self.args.infer), "infer mode"
self.global_step += 1
total_loss = None
loss_dict = {}
        aud = bat['aud_feat'].to(self.device).to(torch.float32)
        poses = bat['poses'].to(self.device).to(torch.float32)
        # Speaker labels are stored with an offset of 20; shift them into
        # [0, num_classes) before one-hot encoding.
        id = bat['speaker'].to(self.device) - 20
        id = F.one_hot(id, self.num_classes)

        aud = aud.permute(0, 2, 1)
        gt_poses = poses.permute(0, 2, 1)
if self.expression:
expression = bat['expression'].to(self.device).to(torch.float32)
gt_poses = torch.cat([gt_poses, expression.permute(0, 2, 1)], dim=2)
pred_poses, _ = self.generator(
aud,
gt_poses,
id,
)
G_loss, G_loss_dict = self.get_loss(
pred_poses=pred_poses,
gt_poses=gt_poses,
pre_poses=None,
mode='training_G',
gt_conf=None,
aud=aud,
)
self.generator_optimizer.zero_grad()
G_loss.backward()
        grad = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.config.Train.max_gradient_norm)
loss_dict['grad'] = grad.item()
self.generator_optimizer.step()
        for key, value in G_loss_dict.items():
            loss_dict[key] = value.item()
return total_loss, loss_dict
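    # A minimal sketch of the batch __call__ expects (key names taken from the
    # code above; exact channel counts depend on the data_utils config):
    #   bat = {
    #       'aud_feat':   float tensor, (B, C_aud, T)
    #       'poses':      float tensor, (B, C_pose, T)
    #       'expression': float tensor, (B, 100, T)  # only if Data.pose.expression
    #       'speaker':    long tensor, (B,), values offset by 20
    #   }
    #   total_loss, loss_dict = wrapper(bat)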
def get_loss(self,
pred_poses,
gt_poses,
pre_poses,
aud,
mode='training_G',
gt_conf=None,
exp=1,
gt_nzero=None,
pre_nzero=None,
):
loss_dict = {}
[b_j, b_e, b_b, b_h, b_f] = self.dim_list
        # Despite the name, this is an L1 loss over the first 6 pose channels;
        # the expression term below is a true MSE over the last 100 channels.
        MSELoss = torch.mean(torch.abs(pred_poses[:, :, :6] - gt_poses[:, :, :6]))
        if self.expression:
            expl = torch.mean((pred_poses[:, :, -100:] - gt_poses[:, :, -100:]) ** 2)
        else:
            expl = 0

        gen_loss = expl + MSELoss

        loss_dict['MSELoss'] = MSELoss
        if self.expression:
            loss_dict['exp_loss'] = expl
return gen_loss, loss_dict
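    # In symbols, writing J for the first 6 pose channels and E for the last
    # 100 expression channels, the generator loss is
    #   L = mean(|pred_J - gt_J|) + mean((pred_E - gt_E)^2),
    # with the second term dropped when expression supervision is disabled.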
def infer_on_audio(self, aud_fn, id=None, initial_pose=None, norm_stats=None, w_pre=False, frame=None, am=None, am_sr=16000, **kwargs):
        '''
        initial_pose: (B, C, T), normalized
        aud_fn: path to an audio file, or a precomputed audio feature tensor
        -> generated motion (B, T, C)
        '''
output = []
# assert self.args.infer, "train mode"
self.generator.eval()
if self.config.Data.pose.normalization:
assert norm_stats is not None
data_mean = norm_stats[0]
data_std = norm_stats[1]
# assert initial_pose.shape[-1] == pre_length
if initial_pose is not None:
            gt = initial_pose.permute(0, 2, 1).to(self.generator.device).to(torch.float32)
            pre_poses = initial_pose[:, :, :15].permute(0, 2, 1).to(self.generator.device).to(torch.float32)
poses = initial_pose.permute(0, 2, 1).to(self.generator.device).to(torch.float32)
B = pre_poses.shape[0]
else:
gt = None
            pre_poses = None
B = 1
        if isinstance(aud_fn, torch.Tensor):
            # aud_fn may already be a feature tensor; move it to the target
            # device and dtype instead of re-wrapping it with torch.tensor().
            aud_feat = aud_fn.detach().to(self.generator.device, torch.float32)
            num_poses_to_generate = aud_feat.shape[-1]
else:
aud_feat = get_mfcc_ta(aud_fn, am=am, am_sr=am_sr, fps=30, encoder_choice='faceformer')
aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.generator.device).transpose(1, 2)
        if frame is None:
            # With encoder_choice='faceformer' the feature length is measured
            # in 16 kHz audio samples; convert samples to 30 fps frames.
            frame = aud_feat.shape[2] * 30 // 16000
if id is None:
id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device)
else:
id = F.one_hot(id, self.num_classes).to(self.generator.device)
with torch.no_grad():
pred_poses = self.generator(aud_feat, pre_poses, id, time_steps=frame)[0]
pred_poses = pred_poses.cpu().numpy()
output = pred_poses
if self.config.Data.pose.normalization:
output = denormalize(output, data_mean, data_std)
return output
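    # A minimal usage sketch (hypothetical file name and speaker index,
    # assuming a trained checkpoint is already loaded into the wrapper):
    #   motion = wrapper.infer_on_audio(
    #       'speech.wav',
    #       id=torch.tensor([1]),              # speaker index in [0, num_classes)
    #       norm_stats=(data_mean, data_std),  # only if normalization is on
    #   )
    #   # motion: numpy array of generated poses, (B, T, C)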
def generate(self, wv2_feat, frame):
        '''
        wv2_feat: audio features (e.g. wav2vec-style), (B, C, T)
        frame: number of motion frames to generate
        -> generated motion tensor (B, T, C)
        '''
        # assert self.args.infer, "train mode"
        self.generator.eval()
id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device)
id = id.repeat(wv2_feat.shape[0], 1)
with torch.no_grad():
pred_poses = self.generator(wv2_feat, None, id, time_steps=frame)[0]
return pred_poses
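# A minimal sketch of driving generate() directly with precomputed audio
# features (hypothetical shapes; the generator expects features as (B, C, T)):
#   feat = torch.randn(1, 64, 16000 * 3)       # e.g. 3 s of audio features
#   motion = wrapper.generate(feat, frame=90)  # 90 frames = 3 s at 30 fps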