Spaces:
Paused
Paused
import train | |
import os | |
import time | |
import csv | |
import sys | |
import warnings | |
import random | |
import numpy as np | |
import time | |
import pprint | |
import pickle | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.tensorboard import SummaryWriter | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
from loguru import logger | |
import smplx | |
import librosa | |
from utils import config, logger_tools, other_tools, metric | |
from utils import rotation_conversions as rc | |
from dataloaders import data_tools | |
from optimizers.optim_factory import create_optimizer | |
from optimizers.scheduler_factory import create_scheduler | |
from optimizers.loss_factory import get_loss_func | |
from scipy.spatial.transform import Rotation | |
class CustomTrainer(train.BaseTrainer): | |
def __init__(self, args):
    """Initialise the trainer on top of ``train.BaseTrainer``.

    After the base-class set-up this stores the joint count of the training
    set, creates the per-epoch metric tracker and builds the two loss
    modules used by the generator step.

    Args:
        args: Run configuration forwarded unchanged to the base class.
            ``args.pose_rep`` is only read for the error message below.
    """
    super().__init__(args)
    self.joints = self.train_data.joints

    # One boolean flag per tracked metric; only "l1div" and "bc" are True.
    # NOTE(review): presumably the flag marks metrics where larger is
    # better -- confirm against other_tools.EpochTracker.
    metric_names = [
        "fid", "l1div", "bc", "rec", "trans", "vel", "transv",
        "dis", "gen", "acc", "transa", "div_reg", "kl",
    ]
    flagged = ("l1div", "bc")
    metric_flags = [name in flagged for name in metric_names]
    self.tracker = other_tools.EpochTracker(metric_names, metric_flags)

    # A non-rot6d configuration is only reported, not rejected:
    # construction continues after the log call.
    rot6d_enabled = self.args.rot6d
    if not rot6d_enabled:
        logger.error(f"this script is for rot6d, your pose rep. is {args.pose_rep}")

    device = self.rank
    self.rec_loss = get_loss_func("GeodesicLoss").to(device)
    self.vel_loss = torch.nn.L1Loss(reduction='mean').to(device)
def _load_data(self, dict_data):
    """Move one batch to ``self.rank`` and assemble the model inputs.

    The axis-angle pose is converted to the 6D rotation representation
    (``rc.axis_angle_to_matrix`` followed by ``rc.matrix_to_rotation_6d``).
    A seed-motion tensor is built as well: it is zero everywhere except for
    the first ``self.args.pre_frames`` frames, which hold the target pose
    and translation, and whose last channel is set to 1.

    Args:
        dict_data: Mapping with the entries "pose", "trans", "facial",
            "beta", "id", "word", "audio" and "emo"; each value must
            support ``.to(device)``.

    Returns:
        dict: ``tar_pose`` (6D), ``in_audio``, ``in_motion`` (seed tensor of
        width ``joints*6 + 3 + 1``), ``tar_trans``, ``tar_exps``,
        ``tar_beta``, ``tar_word``, ``tar_id`` (cast to long), ``in_emo``.
    """
    device = self.rank
    seed_len = self.args.pre_frames

    def on_device(key):
        # Fetch one batch entry and move it to this process' device.
        return dict_data[key].to(device)

    tar_pose = on_device("pose")
    tar_trans = on_device("trans")
    tar_exps = on_device("facial")
    tar_beta = on_device("beta")
    tar_id = on_device("id").long()
    tar_word = on_device("word")
    in_audio = on_device("audio")
    in_emo = on_device("emo")
    # The "sem" entry is not loaded; its handling is disabled.

    bs, n = tar_pose.shape[0], tar_pose.shape[1]
    j = self.joints

    # axis-angle (bs, n, j, 3) -> rotation matrices -> flattened 6D (bs, n, j*6)
    rot_mats = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3))
    tar_pose = rc.matrix_to_rotation_6d(rot_mats).reshape(bs, n, j * 6)

    # Seed frames: pose and translation of the first `seed_len` frames.
    seed = torch.cat(
        [tar_pose[:, :seed_len], tar_trans[:, :seed_len]], dim=2
    ).to(device)

    # Width is j*6 pose values + 3 (the remaining seed channels, i.e. the
    # translation) + 1 indicator channel that marks the seeded frames.
    in_pre_pose = tar_pose.new_zeros((bs, n, j * 6 + 1 + 3)).to(device)
    in_pre_pose[:, :seed_len, :-1] = seed[:, :seed_len]
    in_pre_pose[:, :seed_len, -1] = 1

    return dict(
        tar_pose=tar_pose,
        in_audio=in_audio,
        in_motion=in_pre_pose,
        tar_trans=tar_trans,
        tar_exps=tar_exps,
        tar_beta=tar_beta,
        tar_word=tar_word,
        tar_id=tar_id,
        in_emo=in_emo,
    )
def _d_training(self, loaded_data):
    """Compute the discriminator loss for one batch.

    The generator produces a motion sample, both the sample and the target
    are passed through a 6D -> matrix -> 6D round trip, and the
    discriminator scores both. The loss is
    ``-mean(log(D(real) + 1e-8) + log(1 - D(fake) + 1e-8))`` and is also
    recorded in the tracker under "dis"/"train".

    The generator forward pass runs under ``torch.no_grad()``: this loss is
    only used to update the discriminator (``train`` zeroes the generator
    gradients before the generator step), so building and back-propagating
    through the generator graph would be discarded work.

    Args:
        loaded_data: Batch dict as returned by ``_load_data``.

    Returns:
        torch.Tensor: Scalar discriminator loss.
    """
    bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints

    # No gradients are needed for the generator in the discriminator step.
    with torch.no_grad():
        net_out = self.model(
            in_audio=loaded_data['in_audio'],
            pre_seq=loaded_data["in_motion"],
            in_text=loaded_data["tar_word"],
            in_id=loaded_data["tar_id"],
            in_emo=loaded_data["in_emo"],
            in_facial=loaded_data["tar_exps"],
        )

    # Only the first j*6 channels are the pose; the remaining channels
    # (translation) are not shown to the discriminator.
    rec_pose = net_out["rec_pose"][:, :, :j*6].reshape(bs, n, j, 6)
    # The 6D -> matrix -> 6D round trip is applied to both inputs so that
    # real and fake samples go through the same conversion.
    rec_pose = rc.rotation_6d_to_matrix(rec_pose)
    rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
    tar_pose = rc.rotation_6d_to_matrix(loaded_data["tar_pose"].reshape(bs, n, j, 6))
    tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)

    out_d_fake = self.d_model(rec_pose)
    out_d_real = self.d_model(tar_pose)

    # 1e-8 keeps the logarithms finite when the discriminator saturates.
    d_loss_adv = -torch.mean(torch.log(out_d_real + 1e-8) + torch.log(1 - out_d_fake + 1e-8))
    self.tracker.update_meter("dis", "train", d_loss_adv.item())
    return d_loss_adv
def _g_training(self, loaded_data, use_adv, mode="train"):
    """Run the generator on one batch and compute losses or outputs.

    Args:
        loaded_data: Batch dict as returned by ``_load_data``.
        use_adv: When true and ``mode == 'train'``, the adversarial term
            ``-mean(log(D(fake) + 1e-8))`` is added and tracked as "gen".
        mode: ``'train'``, ``'val'`` or anything else (treated as test).

    Returns:
        For ``'train'``: the summed loss (adversarial + weighted
        reconstruction + weighted translation; disabled terms count as 0).
        For ``'val'``: dict with ``rec_pose``, ``rec_trans``, ``tar_pose``.
        Otherwise: the same dict plus ``tar_exps``, ``tar_beta`` and
        ``tar_trans`` taken from ``loaded_data``.
        Outside of training, when ``self.args.pose_dims < 330``, the poses
        are expanded to 55 joints (``55*6`` channels).
    """
    target_6d = loaded_data["tar_pose"]
    bs, n = target_6d.shape[0], target_6d.shape[1]
    j = self.joints

    net_out = self.model(
        in_audio=loaded_data['in_audio'],
        pre_seq=loaded_data["in_motion"],
        in_text=loaded_data["tar_word"],
        in_id=loaded_data["tar_id"],
        in_emo=loaded_data["in_emo"],
        in_facial=loaded_data["tar_exps"],
    )
    prediction = net_out["rec_pose"]
    # Channel layout of the prediction: [j*6 pose values | 3 translation values].
    rec_trans = prediction[:, :, j*6:j*6+3]
    rec_mat = rc.rotation_6d_to_matrix(prediction[:, :, :j*6].reshape(bs, n, j, 6))
    tar_mat = rc.rotation_6d_to_matrix(target_6d.reshape(bs, n, j, 6))

    # Reconstruction loss is computed on rotation matrices.
    rec_loss = self.rec_loss(tar_mat, rec_mat)
    rec_loss *= self.args.rec_weight
    self.tracker.update_meter("rec", mode, rec_loss.item())

    # Back to the flattened 6D representation for the steps below.
    rec_pose = rc.matrix_to_rotation_6d(rec_mat).reshape(bs, n, j*6)
    tar_pose = rc.matrix_to_rotation_6d(tar_mat).reshape(bs, n, j*6)

    if self.args.pose_dims < 330 and mode != "train":
        def restore_all_joints(pose_6d):
            # 6D -> axis-angle, scatter the selected joints back through
            # the joint mask, then convert the 55-joint pose to 6D again.
            axis_angle = rc.rotation_6d_to_matrix(pose_6d.reshape(bs, n, j, 6))
            axis_angle = rc.matrix_to_axis_angle(axis_angle).reshape(bs, n, j*3)
            axis_angle = self.inverse_selection_tensor(
                axis_angle, self.train_data.joint_mask, axis_angle.shape[0]
            )
            full_mat = rc.axis_angle_to_matrix(axis_angle.reshape(bs, n, 55, 3))
            return rc.matrix_to_rotation_6d(full_mat).reshape(bs, n, 55*6)

        rec_pose = restore_all_joints(rec_pose)
        tar_pose = restore_all_joints(tar_pose)

    # Adversarial term: only during training and only once enabled.
    d_loss_adv = 0
    if use_adv and mode == 'train':
        out_d_fake = self.d_model(rec_pose)
        d_loss_adv = -torch.mean(torch.log(out_d_fake + 1e-8))
        self.tracker.update_meter("gen", mode, d_loss_adv.item())

    # Translation term: L1 loss, scaled by the same weight as the pose loss.
    trans_loss = 0
    if self.args.train_trans:
        trans_loss = self.vel_loss(rec_trans, loaded_data["tar_trans"])
        trans_loss *= self.args.rec_weight
        self.tracker.update_meter("trans", mode, trans_loss.item())

    if mode == 'train':
        return d_loss_adv + rec_loss + trans_loss

    outputs = {
        'rec_pose': rec_pose,
        'rec_trans': rec_trans,
        'tar_pose': tar_pose,
    }
    if mode != 'val':
        # Test mode additionally hands back the ground-truth side data.
        outputs['tar_exps'] = loaded_data["tar_exps"]
        outputs['tar_beta'] = loaded_data["tar_beta"]
        outputs['tar_trans'] = loaded_data["tar_trans"]
    return outputs
def train(self, epoch):
    """Train generator (and, once enabled, discriminator) for one epoch.

    For every batch of ``self.train_loader`` the discriminator is updated
    first (only when ``epoch >= self.args.no_adv_epoch``), then the
    generator. Timing, memory and learning rates are reported through
    ``self.train_recording`` every ``self.args.log_period`` iterations.
    Both schedulers are stepped once after the loop.

    Args:
        epoch: Current epoch index; compared against
            ``self.args.no_adv_epoch`` and passed to the schedulers.
    """
    use_adv = bool(epoch >= self.args.no_adv_epoch)
    self.model.train()
    self.d_model.train()
    self.tracker.reset()
    t_start = time.time()
    for its, batch_data in enumerate(self.train_loader):
        loaded_data = self._load_data(batch_data)
        # Time spent waiting for / preparing the batch.
        t_data = time.time() - t_start

        if use_adv:
            # Discriminator step.
            self.opt_d.zero_grad()
            d_loss_final = self._d_training(loaded_data)
            d_loss_final.backward()
            self.opt_d.step()

        # Generator step. zero_grad() also clears anything the
        # discriminator step may have left on the generator parameters.
        self.opt.zero_grad()
        g_loss_final = self._g_training(loaded_data, use_adv, 'train')
        g_loss_final.backward()
        self.opt.step()

        # memory_reserved() replaces the deprecated memory_cached();
        # the value is reported in units of 1e9 bytes.
        mem_cost = torch.cuda.memory_reserved() / 1E9
        lr_g = self.opt.param_groups[0]['lr']
        lr_d = self.opt_d.param_groups[0]['lr']
        t_train = time.time() - t_start - t_data
        t_start = time.time()
        if its % self.args.log_period == 0:
            self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g, lr_d=lr_d)
        # Debug runs stop after the second batch.
        if self.args.debug and its == 1:
            break

    # Schedulers advance once per epoch.
    self.opt_s.step(epoch)
    self.opt_d_s.step(epoch)
def val(self, epoch):
    """Compute the validation FID between generated and target motion.

    Each batch is run through the generator in ``'val'`` mode, optionally
    upsampled to 30 fps, trimmed to a multiple of
    ``self.args.vae_test_len`` frames and mapped to latents with
    ``self.eval_copy.map2latent``. The Frechet distance over all collected
    latents is stored in the tracker as "fid"/"val" and
    ``self.val_recording(epoch)`` is called.

    Args:
        epoch: Current epoch index, forwarded to ``val_recording``.

    Raises:
        AssertionError: If 30 is not a multiple of ``self.args.pose_fps``.
        ValueError: From ``np.concatenate`` when the loader yields no batch.
    """
    self.model.eval()
    self.d_model.eval()
    latent_out_chunks = []
    latent_ori_chunks = []
    with torch.no_grad():
        # NOTE(review): validation iterates self.train_loader -- confirm
        # whether a dedicated validation loader was intended.
        for its, batch_data in enumerate(self.train_loader):
            loaded_data = self._load_data(batch_data)
            net_out = self._g_training(loaded_data, False, 'val')
            tar_pose = net_out['tar_pose']
            rec_pose = net_out['rec_pose']

            upsample = 30 / self.args.pose_fps
            if upsample != 1:
                assert 30 % self.args.pose_fps == 0
                # interpolate() works on (batch, channels, time), hence the permutes.
                tar_pose = torch.nn.functional.interpolate(
                    tar_pose.permute(0, 2, 1), scale_factor=upsample, mode='linear'
                ).permute(0, 2, 1)
                rec_pose = torch.nn.functional.interpolate(
                    rec_pose.permute(0, 2, 1), scale_factor=upsample, mode='linear'
                ).permute(0, 2, 1)

            # Drop trailing frames that do not fill a whole evaluation window.
            n = tar_pose.shape[1]
            remain = n % self.args.vae_test_len
            tar_pose = tar_pose[:, :n-remain, :]
            rec_pose = rec_pose[:, :n-remain, :]

            latent_out_chunks.append(
                self.eval_copy.map2latent(rec_pose).reshape(-1, self.args.vae_length).cpu().numpy()
            )
            latent_ori_chunks.append(
                self.eval_copy.map2latent(tar_pose).reshape(-1, self.args.vae_length).cpu().numpy()
            )

            # Debug runs stop after the second batch.
            if self.args.debug and its == 1:
                break

    # Concatenate once instead of re-copying the growing array per batch.
    latent_out_motion_all = np.concatenate(latent_out_chunks, axis=0)
    latent_ori_all = np.concatenate(latent_ori_chunks, axis=0)
    fid_motion = data_tools.FIDCalculator.frechet_distance(latent_out_motion_all, latent_ori_all)
    self.tracker.update_meter("fid", "val", fid_motion)
    self.val_recording(epoch)
def test(self, epoch):
    """Generate motion for the test set, save it and report metrics.

    For every batch of ``self.test_loader`` this runs the generator in
    ``'test'`` mode, upsamples all sequences to 30 fps when needed, collects
    latents for FID, runs SMPL-X on the generated pose to obtain joints
    (used for the L1 diversity and, if ``self.alignmenter`` is set, the beat
    alignment score) and writes ``gt_<id>.npz`` / ``res_<id>.npz`` into
    ``<checkpoint_path>/<epoch>/``. FID, "bc" and "l1div" are then logged
    and passed to ``self.test_recording``.

    Args:
        epoch: Epoch index; names the output directory and is forwarded to
            ``test_recording``.

    Returns:
        ``0`` when the output directory already exists (nothing is
        recomputed); otherwise ``None``.

    Raises:
        AssertionError: If 30 is not a multiple of ``self.args.pose_fps``.
    """
    results_save_path = self.checkpoint_path + f"/{epoch}/"
    if os.path.exists(results_save_path):
        # Results for this epoch were already written.
        return 0
    os.makedirs(results_save_path)
    start_time = time.time()
    total_length = 0
    test_seq_list = self.test_data.selected_file
    align = 0
    latent_out = []
    latent_ori = []
    self.model.eval()
    self.smplx.eval()
    self.eval_copy.eval()

    upsample = 30 / self.args.pose_fps

    def _to_30fps(seq):
        # interpolate() works on (batch, channels, time), hence the permutes.
        return torch.nn.functional.interpolate(
            seq.permute(0, 2, 1), scale_factor=upsample, mode='linear'
        ).permute(0, 2, 1)

    with torch.no_grad():
        for its, batch_data in enumerate(self.test_loader):
            # Batch index `its` is used as the row index of the sequence list.
            seq_id = test_seq_list.iloc[its]['id']

            loaded_data = self._load_data(batch_data)
            net_out = self._g_training(loaded_data, False, 'test')
            tar_pose = net_out['tar_pose']
            rec_pose = net_out['rec_pose']
            tar_exps = net_out['tar_exps']
            tar_beta = net_out['tar_beta']
            rec_trans = net_out['rec_trans']
            tar_trans = net_out['tar_trans']
            # From here on the pose is treated as 55 joints.
            bs, n, j = tar_pose.shape[0], tar_pose.shape[1], 55

            if upsample != 1:
                assert 30 % self.args.pose_fps == 0
                n *= int(upsample)
                tar_pose = _to_30fps(tar_pose)
                rec_pose = _to_30fps(rec_pose)
                tar_beta = _to_30fps(tar_beta)
                tar_exps = _to_30fps(tar_exps)
                tar_trans = _to_30fps(tar_trans)
                rec_trans = _to_30fps(rec_trans)

            # Latents for FID; trailing frames that do not fill a whole
            # evaluation window are dropped.
            remain = n % self.args.vae_test_len
            latent_out.append(
                self.eval_copy.map2latent(rec_pose[:, :n-remain])
                .reshape(-1, self.args.vae_length).detach().cpu().numpy()
            )
            latent_ori.append(
                self.eval_copy.map2latent(tar_pose[:, :n-remain])
                .reshape(-1, self.args.vae_length).detach().cpu().numpy()
            )

            # 6D -> axis-angle, flattened over batch and time.
            rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6))
            rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3)
            tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6))
            tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3)

            # `x - x` yields zeros of matching shape: translation and
            # expression are zeroed for the joint computation.
            vertices_rec = self.smplx(
                betas=tar_beta.reshape(bs*n, 300),
                transl=rec_trans.reshape(bs*n, 3)-rec_trans.reshape(bs*n, 3),
                expression=tar_exps.reshape(bs*n, 100)-tar_exps.reshape(bs*n, 100),
                jaw_pose=rec_pose[:, 66:69],
                global_orient=rec_pose[:, :3],
                body_pose=rec_pose[:, 3:21*3+3],
                left_hand_pose=rec_pose[:, 25*3:40*3],
                right_hand_pose=rec_pose[:, 40*3:55*3],
                return_joints=True,
                leye_pose=rec_pose[:, 69:72],
                reye_pose=rec_pose[:, 72:75],
            )
            # Only joints of the generated motion are computed; the first
            # 55 of 127 joints are kept.
            # NOTE(review): reshape(1, n, ...) only works when bs == 1.
            joints_rec = vertices_rec["joints"].detach().cpu().numpy().reshape(1, n, 127*3)[0, :n, :55*3]

            _ = self.l1_calculator.run(joints_rec)
            if self.alignmenter is not None:
                in_audio_eval, sr = librosa.load(self.args.data_path+"wave16k/"+seq_id+".wav")
                in_audio_eval = librosa.resample(in_audio_eval, orig_sr=sr, target_sr=self.args.audio_sr)
                # NOTE(review): the audio offset and slice length use
                # pose_fps while `n` and load_pose use 30 fps; these only
                # agree when pose_fps == 30 -- confirm.
                a_offset = int(self.align_mask * (self.args.audio_sr / self.args.pose_fps))
                onset_bt = self.alignmenter.load_audio(
                    in_audio_eval[:int(self.args.audio_sr / self.args.pose_fps*n)],
                    a_offset, len(in_audio_eval)-a_offset, True,
                )
                beat_vel = self.alignmenter.load_pose(joints_rec, self.align_mask, n-self.align_mask, 30, True)
                # Weighted by the number of evaluated frames; normalised after the loop.
                align += (self.alignmenter.calculate_align(onset_bt, beat_vel, 30) * (n-2*self.align_mask))

            tar_pose_axis_np = tar_pose.detach().cpu().numpy()
            rec_pose_axis_np = rec_pose.detach().cpu().numpy()
            rec_trans_np = rec_trans.detach().cpu().numpy().reshape(bs*n, 3)
            tar_trans_np = tar_trans.detach().cpu().numpy().reshape(bs*n, 3)
            # Expressions are saved as zeros for both files (x - x).
            tar_exps_np = tar_exps.detach().cpu().numpy().reshape(bs*n, 100)
            rec_exp_np = tar_exps_np - tar_exps_np
            tar_exp_np = tar_exps_np - tar_exps_np

            # Read the betas inside a context manager so the .npz file
            # handle is closed for every sequence.
            # allow_pickle=True can execute code from a crafted file; the
            # dataset under data_path is assumed to be trusted.
            gt_path = self.args.data_path+self.args.pose_rep +"/"+seq_id+".npz"
            with np.load(gt_path, allow_pickle=True) as gt_npz:
                gt_betas = gt_npz["betas"]

            if not self.args.train_trans:
                # Translation was not trained: save zeros for both files.
                tar_trans_np = tar_trans_np - tar_trans_np
                rec_trans_np = rec_trans_np - rec_trans_np

            np.savez(
                results_save_path+"gt_"+seq_id+'.npz',
                betas=gt_betas,
                poses=tar_pose_axis_np,
                expressions=tar_exp_np,
                trans=tar_trans_np,
                model='smplx2020',
                gender='neutral',
                mocap_frame_rate=30,
            )
            np.savez(
                results_save_path+"res_"+seq_id+'.npz',
                betas=gt_betas,
                poses=rec_pose_axis_np,
                expressions=rec_exp_np,
                trans=rec_trans_np,
                model='smplx2020',
                gender='neutral',
                mocap_frame_rate=30,
            )
            total_length += n

    latent_out_all = np.concatenate(latent_out, axis=0)
    latent_ori_all = np.concatenate(latent_ori, axis=0)
    fid = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all)
    logger.info(f"fid score: {fid}")
    self.test_recording("fid", fid, epoch)

    # Normalise by the total number of frames that entered the alignment
    # score (each sequence loses align_mask frames at both ends).
    align_avg = align/(total_length-2*len(self.test_loader)*self.align_mask)
    logger.info(f"align score: {align_avg}")
    self.test_recording("bc", align_avg, epoch)

    l1div = self.l1_calculator.avg()
    logger.info(f"l1div score: {l1div}")
    self.test_recording("l1div", l1div, epoch)

    end_time = time.time() - start_time
    # NOTE(review): total_length counts 30 fps frames but is divided by
    # pose_fps here; the reported duration is only exact for pose_fps == 30.
    logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion")