# EMAGE / aeface_trainer.py
# Uploaded by H-Liu1997 using huggingface_hub (revision 2d47d90, verified).
import train
import os
import time
import csv
import sys
import warnings
import random
import numpy as np
import time
import pprint
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
from loguru import logger
import smplx
from utils import config, logger_tools, other_tools, metric
from utils import rotation_conversions as rc
from dataloaders import data_tools
from optimizers.optim_factory import create_optimizer
from optimizers.scheduler_factory import create_scheduler
from optimizers.loss_factory import get_loss_func
from scipy.spatial.transform import Rotation
class CustomTrainer(train.BaseTrainer):
"""
Trainer for motion representation learning on face data.

Each batch's axis-angle pose is converted to the 6D rotation representation,
concatenated with the facial expression parameters and passed through
``self.model``; the reconstruction is scored with rotation, expression and
(optionally) SMPL-X vertex losses. ``train``/``val`` accumulate those losses
in ``self.tracker``; ``test`` writes ground-truth and reconstructed sequences
to disk.
"""
def __init__(self, args):
    """Initialise the base trainer, the loss meters and the loss functions.

    Args:
        args: configuration object forwarded unchanged to
            ``train.BaseTrainer.__init__``.
    """
    super().__init__(args)
    self.joints = self.train_data.joints

    # One meter per loss term reported by train()/val().
    meter_names = [
        "rec", "vel", "acc", "com",
        "face", "face_vel", "face_acc",
        "ver", "ver_vel", "ver_acc",
    ]
    # One False flag per meter; the flag's meaning is defined by
    # EpochTracker, which is not visible in this file.
    meter_flags = [False] * len(meter_names)
    self.tracker = other_tools.EpochTracker(meter_names, meter_flags)

    # Loss on rotation matrices, looked up by name in the loss factory.
    self.rec_loss = get_loss_func("GeodesicLoss")
    # All remaining terms are mean-reduced MSE losses.
    self.mse_loss = torch.nn.MSELoss(reduction='mean')
    self.vel_loss = torch.nn.MSELoss(reduction='mean')
    self.vectices_loss = torch.nn.MSELoss(reduction='mean')
def inverse_selection(self, filtered_t, selection_array, n):
    """Scatter mask-selected columns back into their full-width layout.

    Args:
        filtered_t: 2-D array-like; row ``i`` holds the values of the
            selected columns for output row ``i``. Only the first ``n``
            rows are used.
        selection_array: 1-D NumPy array; entries equal to 1 mark the
            columns that ``filtered_t`` provides.
        n: number of rows in the result.

    Returns:
        A float64 array of shape ``(n, selection_array.size)`` holding the
        values of ``filtered_t`` in the selected columns and zeros in all
        other columns.
    """
    # Start from an all-zero array so unselected columns stay zero.
    original_shape_t = np.zeros((n, selection_array.size))
    # Column indices whose mask entry is 1.
    selected_indices = np.where(selection_array == 1)[0]
    # Fill all rows at once instead of looping over them in Python.
    original_shape_t[:, selected_indices] = np.asarray(filtered_t)[:n]
    return original_shape_t
def train(self, epoch):
    """Run one training epoch over ``self.train_loader``.

    For every batch the axis-angle pose is converted to the 6D rotation
    representation, concatenated with the facial expression parameters and
    fed to ``self.model``. The summed loss consists of:

    * geodesic loss on the reconstructed rotation matrices,
    * MSE on their first and second temporal differences,
    * MSE on the expression parameters and their temporal differences,
    * (if ``args.rec_ver_weight > 0``) MSE on SMPL-X vertices and their
      temporal differences,
    * (if ``"VQVAE"`` is in ``args.g_name``) the model's embedding loss.

    Every term is also pushed to ``self.tracker`` under the ``"train"`` split.

    Args:
        epoch: current epoch index; passed to ``self.train_recording`` and
            to the scheduler's ``step``.
    """
    self.model.train()
    t_start = time.time()
    self.tracker.reset()
    for its, dict_data in enumerate(self.train_loader):
        tar_beta = dict_data["beta"].cuda()
        tar_trans = dict_data["trans"].cuda()
        tar_pose = dict_data["pose"].cuda()
        bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
        tar_exps = dict_data["facial"].to(self.rank)
        # axis-angle -> rotation matrix -> 6D, flattened per frame
        tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3))
        tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
        in_tar_pose = torch.cat([tar_pose, tar_exps], -1)
        t_data = time.time() - t_start

        self.opt.zero_grad()
        g_loss_final = 0
        net_out = self.model(in_tar_pose)

        # Rotation reconstruction loss, computed on rotation matrices.
        rec_pose = net_out["rec_pose"][:, :, :j*6]
        rec_pose = rec_pose.reshape(bs, n, j, 6)
        rec_pose = rc.rotation_6d_to_matrix(rec_pose)
        tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6))
        loss_rec = self.rec_loss(rec_pose, tar_pose) * self.args.rec_weight * self.args.rec_pos_weight
        self.tracker.update_meter("rec", "train", loss_rec.item())
        g_loss_final += loss_rec

        # Rotation velocity / acceleration (first / second difference over frames).
        velocity_loss = self.vel_loss(rec_pose[:, 1:] - rec_pose[:, :-1], tar_pose[:, 1:] - tar_pose[:, :-1]) * self.args.rec_weight
        acceleration_loss = self.vel_loss(rec_pose[:, 2:] + rec_pose[:, :-2] - 2 * rec_pose[:, 1:-1], tar_pose[:, 2:] + tar_pose[:, :-2] - 2 * tar_pose[:, 1:-1]) * self.args.rec_weight
        self.tracker.update_meter("vel", "train", velocity_loss.item())
        self.tracker.update_meter("acc", "train", acceleration_loss.item())
        g_loss_final += velocity_loss
        g_loss_final += acceleration_loss

        # Expression parameter loss (MSE).
        rec_exps = net_out["rec_pose"][:, :, j*6:]
        loss_face = self.mse_loss(rec_exps, tar_exps) * self.args.rec_weight
        self.tracker.update_meter("face", "train", loss_face.item())
        g_loss_final += loss_face

        # Expression velocity / acceleration loss (MSE).
        face_velocity_loss = self.vel_loss(rec_exps[:, 1:] - rec_exps[:, :-1], tar_exps[:, 1:] - tar_exps[:, :-1]) * self.args.rec_weight
        face_acceleration_loss = self.vel_loss(rec_exps[:, 2:] + rec_exps[:, :-2] - 2 * rec_exps[:, 1:-1], tar_exps[:, 2:] + tar_exps[:, :-2] - 2 * tar_exps[:, 1:-1]) * self.args.rec_weight
        self.tracker.update_meter("face_vel", "train", face_velocity_loss.item())
        self.tracker.update_meter("face_acc", "train", face_acceleration_loss.item())
        g_loss_final += face_velocity_loss
        g_loss_final += face_acceleration_loss

        # SMPL-X vertex losses.
        if self.args.rec_ver_weight > 0:
            tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3)
            rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3)
            betas = tar_beta.reshape(bs*n, 300)
            # Translation is zeroed (the original subtracted it from itself).
            zero_transl = torch.zeros_like(tar_trans.reshape(bs*n, 3))

            def zeros(dim):
                return torch.zeros(bs*n, dim).cuda()

            def face_vertices(jaw_pose, expression):
                # Only jaw pose and expression vary; every other pose input is zero.
                output = self.smplx(
                    betas=betas,
                    transl=zero_transl,
                    expression=expression.reshape(bs*n, 100),
                    jaw_pose=jaw_pose,
                    global_orient=zeros(3),
                    body_pose=zeros(21*3),
                    left_hand_pose=zeros(15*3),
                    right_hand_pose=zeros(15*3),
                    return_verts=True,
                    leye_pose=zeros(3),
                    reye_pose=zeros(3),
                )
                # Restore the (batch, frame) axes so temporal differences
                # below run over frames rather than over the vertex index.
                # NOTE(review): assumes smplx returns vertices as
                # (bs*n, num_vertices, 3) -- confirm against self.smplx.
                return output['vertices'].reshape(bs, n, -1, 3)

            # FIX: the original paired the reconstructed jaw with the target
            # expression (and vice versa); each mesh now uses one consistent
            # source for both inputs.
            verts_rec = face_vertices(rec_pose, rec_exps)
            verts_tar = face_vertices(tar_pose, tar_exps)

            ver_weight = self.args.rec_weight * self.args.rec_ver_weight
            vectices_loss = self.mse_loss(verts_rec, verts_tar)
            self.tracker.update_meter("ver", "train", vectices_loss.item() * ver_weight)
            g_loss_final += vectices_loss * ver_weight

            # Vertex velocity / acceleration over frames.
            vert_velocity_loss = self.vel_loss(verts_rec[:, 1:] - verts_rec[:, :-1], verts_tar[:, 1:] - verts_tar[:, :-1]) * ver_weight
            vert_acceleration_loss = self.vel_loss(verts_rec[:, 2:] + verts_rec[:, :-2] - 2 * verts_rec[:, 1:-1], verts_tar[:, 2:] + verts_tar[:, :-2] - 2 * verts_tar[:, 1:-1]) * ver_weight
            self.tracker.update_meter("ver_vel", "train", vert_velocity_loss.item())
            self.tracker.update_meter("ver_acc", "train", vert_acceleration_loss.item())
            g_loss_final += vert_velocity_loss
            g_loss_final += vert_acceleration_loss

        # Codebook / embedding loss for VQ-VAE style generators.
        if "VQVAE" in self.args.g_name:
            loss_embedding = net_out["embedding_loss"]
            g_loss_final += loss_embedding
            self.tracker.update_meter("com", "train", loss_embedding.item())

        g_loss_final.backward()
        if self.args.grad_norm != 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_norm)
        self.opt.step()

        t_train = time.time() - t_start - t_data
        t_start = time.time()
        # memory_cached() is deprecated; memory_reserved() is its replacement.
        mem_cost = torch.cuda.memory_reserved() / 1E9
        lr_g = self.opt.param_groups[0]['lr']
        if its % self.args.log_period == 0:
            self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g)
        # Debug mode stops after the second batch.
        if self.args.debug and its == 1:
            break
    self.opt_s.step(epoch)
def val(self, epoch):
    """Evaluate the model on ``self.val_loader`` without gradient tracking.

    Computes the same loss terms as ``train`` (rotation, expression,
    optional SMPL-X vertex terms and the optional VQ-VAE embedding loss)
    and records each one in ``self.tracker`` under the ``"val"`` split.
    No optimisation step is performed.

    Args:
        epoch: current epoch index; passed to ``self.val_recording``.
    """
    self.model.eval()
    with torch.no_grad():
        for dict_data in self.val_loader:
            tar_beta = dict_data["beta"].cuda()
            tar_trans = dict_data["trans"].cuda()
            tar_pose = dict_data["pose"].cuda()
            bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
            tar_exps = dict_data["facial"].to(self.rank)
            # axis-angle -> rotation matrix -> 6D, flattened per frame
            tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3))
            tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
            in_tar_pose = torch.cat([tar_pose, tar_exps], -1)

            net_out = self.model(in_tar_pose)

            # Rotation reconstruction loss, computed on rotation matrices.
            rec_pose = net_out["rec_pose"][:, :, :j*6]
            rec_pose = rec_pose.reshape(bs, n, j, 6)
            rec_pose = rc.rotation_6d_to_matrix(rec_pose)
            tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6))
            loss_rec = self.rec_loss(rec_pose, tar_pose) * self.args.rec_weight * self.args.rec_pos_weight
            self.tracker.update_meter("rec", "val", loss_rec.item())

            # Rotation velocity / acceleration (first / second difference over frames).
            velocity_loss = self.vel_loss(rec_pose[:, 1:] - rec_pose[:, :-1], tar_pose[:, 1:] - tar_pose[:, :-1]) * self.args.rec_weight
            acceleration_loss = self.vel_loss(rec_pose[:, 2:] + rec_pose[:, :-2] - 2 * rec_pose[:, 1:-1], tar_pose[:, 2:] + tar_pose[:, :-2] - 2 * tar_pose[:, 1:-1]) * self.args.rec_weight
            self.tracker.update_meter("vel", "val", velocity_loss.item())
            self.tracker.update_meter("acc", "val", acceleration_loss.item())

            # Expression parameter loss (MSE). Uses mse_loss like train();
            # the original used vel_loss here, which is the same MSELoss.
            rec_exps = net_out["rec_pose"][:, :, j*6:]
            loss_face = self.mse_loss(rec_exps, tar_exps) * self.args.rec_weight
            self.tracker.update_meter("face", "val", loss_face.item())

            # Expression velocity / acceleration loss (MSE).
            face_velocity_loss = self.vel_loss(rec_exps[:, 1:] - rec_exps[:, :-1], tar_exps[:, 1:] - tar_exps[:, :-1]) * self.args.rec_weight
            face_acceleration_loss = self.vel_loss(rec_exps[:, 2:] + rec_exps[:, :-2] - 2 * rec_exps[:, 1:-1], tar_exps[:, 2:] + tar_exps[:, :-2] - 2 * tar_exps[:, 1:-1]) * self.args.rec_weight
            self.tracker.update_meter("face_vel", "val", face_velocity_loss.item())
            self.tracker.update_meter("face_acc", "val", face_acceleration_loss.item())

            # SMPL-X vertex losses.
            if self.args.rec_ver_weight > 0:
                tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3)
                rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3)
                betas = tar_beta.reshape(bs*n, 300)
                # Translation is zeroed (the original subtracted it from itself).
                zero_transl = torch.zeros_like(tar_trans.reshape(bs*n, 3))

                def zeros(dim):
                    return torch.zeros(bs*n, dim).cuda()

                def face_vertices(jaw_pose, expression):
                    # Only jaw pose and expression vary; every other pose input is zero.
                    output = self.smplx(
                        betas=betas,
                        transl=zero_transl,
                        expression=expression.reshape(bs*n, 100),
                        jaw_pose=jaw_pose,
                        global_orient=zeros(3),
                        body_pose=zeros(21*3),
                        left_hand_pose=zeros(15*3),
                        right_hand_pose=zeros(15*3),
                        return_verts=True,
                        leye_pose=zeros(3),
                        reye_pose=zeros(3),
                    )
                    # Restore the (batch, frame) axes so temporal differences
                    # below run over frames rather than over the vertex index.
                    # NOTE(review): assumes smplx returns vertices as
                    # (bs*n, num_vertices, 3) -- confirm against self.smplx.
                    return output['vertices'].reshape(bs, n, -1, 3)

                # FIX: the original paired the reconstructed jaw with the
                # target expression (and vice versa); each mesh now uses one
                # consistent source for both inputs.
                verts_rec = face_vertices(rec_pose, rec_exps)
                verts_tar = face_vertices(tar_pose, tar_exps)

                ver_weight = self.args.rec_weight * self.args.rec_ver_weight
                vectices_loss = self.mse_loss(verts_rec, verts_tar)
                self.tracker.update_meter("ver", "val", vectices_loss.item() * ver_weight)

                # Vertex velocity / acceleration over frames.
                vert_velocity_loss = self.vel_loss(verts_rec[:, 1:] - verts_rec[:, :-1], verts_tar[:, 1:] - verts_tar[:, :-1]) * ver_weight
                vert_acceleration_loss = self.vel_loss(verts_rec[:, 2:] + verts_rec[:, :-2] - 2 * verts_rec[:, 1:-1], verts_tar[:, 2:] + verts_tar[:, :-2] - 2 * verts_tar[:, 1:-1]) * ver_weight
                self.tracker.update_meter("ver_vel", "val", vert_velocity_loss.item())
                self.tracker.update_meter("ver_acc", "val", vert_acceleration_loss.item())

            # Codebook / embedding loss for VQ-VAE style generators.
            if "VQVAE" in self.args.g_name:
                loss_embedding = net_out["embedding_loss"]
                self.tracker.update_meter("com", "val", loss_embedding.item())
    self.val_recording(epoch)
def test(self, epoch):
    """Reconstruct every test sequence and write GT / result files to disk.

    Output goes to ``<self.checkpoint_path>/<epoch>/``. For each sequence a
    ``gt_<id>`` and a ``res_<id>`` file is written: ``.npz`` when
    ``"smplx"`` is in ``args.pose_rep``, otherwise ``.bvh`` (header copied
    from the source BVH, rotations written as XYZ Euler angles in degrees).

    Args:
        epoch: epoch index used to name the output directory.

    Returns:
        ``0`` if the output directory already exists (nothing is done),
        otherwise ``None``.
    """
    results_save_path = self.checkpoint_path + f"/{epoch}/"
    if os.path.exists(results_save_path):
        # Results for this epoch were already written; skip.
        return 0
    os.makedirs(results_save_path)
    start_time = time.time()
    total_length = 0
    test_seq_list = self.test_data.selected_file
    self.model.eval()
    with torch.no_grad():
        for its, dict_data in enumerate(self.test_loader):
            tar_pose = dict_data["pose"].cuda()
            tar_exps = dict_data["facial"].to(self.rank)
            bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
            # axis-angle -> rotation matrix -> 6D, flattened per frame
            tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3))
            tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
            # Drop trailing frames so the length is a multiple of pose_length.
            remain = n % self.args.pose_length
            tar_pose = tar_pose[:, :n-remain, :]
            in_tar_pose = torch.cat([tar_pose, tar_exps[:, :n-remain, :]], -1)

            net_out = self.model(in_tar_pose)
            rec_pose = net_out["rec_pose"][:, :, :j*6]
            # From here on n is the number of frames the model returned.
            n = rec_pose.shape[1]
            tar_pose = tar_pose[:, :n, :]
            rec_pose = rec_pose.reshape(bs, n, j, 6)
            rec_pose = rc.rotation_6d_to_matrix(rec_pose)
            rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3)
            rec_pose = rec_pose.cpu().numpy()
            rec_exps = net_out["rec_pose"][:, :, j*6:]
            rec_exps = rec_exps.cpu().numpy().reshape(bs*n, 100)

            tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6))
            tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3)
            tar_pose = tar_pose.cpu().numpy()
            total_length += n

            # --- save --- #
            seq_id = test_seq_list.iloc[its]['id']
            if 'smplx' in self.args.pose_rep:
                # Read what is needed and close the archive right away.
                with np.load(self.args.data_path+self.args.pose_rep+"/"+seq_id+'.npz', allow_pickle=True) as gt_npz:
                    gt_betas = gt_npz["betas"]
                    gt_expressions = gt_npz["expressions"]
                    gt_trans = gt_npz["trans"]
                stride = int(30 / self.args.pose_fps)
                # Translation is zeroed (the original subtracted it from itself).
                zero_trans = np.zeros_like(gt_trans[::stride][:n])

                # Put the selected joint channels back into the full pose vector.
                tar_pose = self.inverse_selection(tar_pose, self.test_data.joint_mask, tar_pose.shape[0])
                np.savez(results_save_path+"gt_"+seq_id+'.npz',
                    betas=gt_betas,
                    poses=tar_pose[:n],
                    expressions=gt_expressions,
                    trans=zero_trans,
                    model='smplx2020',
                    gender='neutral',
                    mocap_frame_rate=30,
                )
                rec_pose = self.inverse_selection(rec_pose, self.test_data.joint_mask, rec_pose.shape[0])
                np.savez(results_save_path+"res_"+seq_id+'.npz',
                    betas=gt_betas,
                    poses=rec_pose,
                    expressions=rec_exps,
                    trans=zero_trans,
                    model='smplx2020',
                    gender='neutral',
                    mocap_frame_rate=30,
                )
            else:
                # axis-angle -> XYZ Euler angles in degrees
                rec_pose = rc.axis_angle_to_matrix(torch.from_numpy(rec_pose.reshape(bs*n, j, 3)))
                rec_pose = torch.rad2deg(rc.matrix_to_euler_angles(rec_pose, "XYZ")).reshape(bs*n, j*3).numpy()
                tar_pose = rc.axis_angle_to_matrix(torch.from_numpy(tar_pose.reshape(bs*n, j, 3)))
                tar_pose = torch.rad2deg(rc.matrix_to_euler_angles(tar_pose, "XYZ")).reshape(bs*n, j*3).numpy()

                def bvh_row(frame):
                    text = np.array2string(frame, max_line_width=np.inf, precision=6, suppress_small=False, separator=' ')
                    # FIX: strip only the enclosing brackets. The original
                    # [1:-2] also removed the last character of the last
                    # value whenever it was not padded.
                    return text[1:-1] + '\n'

                with open(f"{self.args.data_path}{self.args.pose_rep}/{seq_id}.bvh", "r") as f_demo, \
                        open(results_save_path+"gt_"+seq_id+'.bvh', 'w+') as f_gt, \
                        open(results_save_path+"res_"+seq_id+'.bvh', 'w+') as f_real:
                    # The first 431 lines of the source BVH are copied
                    # verbatim to both outputs as the header.
                    header_lines = f_demo.readlines()[:431]
                    f_real.writelines(header_lines)
                    f_gt.writelines(header_lines)
                    for line_id in range(n):
                        f_real.write(bvh_row(rec_pose[line_id]))
                    for line_id in range(n):
                        f_gt.write(bvh_row(tar_pose[line_id]))
    end_time = time.time() - start_time
    logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion")