# EMAGE / ae_trainer.py
# H-Liu1997's picture
# Upload folder using huggingface_hub
# 2d47d90 verified
import train
import os
import time
import csv
import sys
import warnings
import random
import numpy as np
import time
import pprint
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
from loguru import logger
import smplx
from utils import config, logger_tools, other_tools, metric
from utils import rotation_conversions as rc
from dataloaders import data_tools
from optimizers.optim_factory import create_optimizer
from optimizers.scheduler_factory import create_scheduler
from optimizers.loss_factory import get_loss_func
from scipy.spatial.transform import Rotation
class CustomTrainer(train.BaseTrainer):
    """
    Motion representation learning.

    Trainer that converts axis-angle pose sequences to rot6d, passes them
    through ``self.model`` and scores the reconstruction with a geodesic
    rotation loss, velocity/acceleration losses and (when
    ``args.rec_ver_weight > 0``) SMPL-X vertex losses. ``test`` writes
    ground-truth/reconstruction pairs to disk as ``.npz`` or ``.bvh`` files.
    """
def __init__(self, args):
    """Build the SMPL-X body model, the loss meters and the loss functions.

    Args:
        args: Run configuration, forwarded unchanged to the base trainer.
            This method reads ``data_path_1`` and ``rot6d`` through
            ``self.args`` (presumably set by the base class) and
            ``pose_rep`` from ``args`` for the error message.
    """
    super().__init__(args)
    self.joints = self.train_data.joints

    # Body model used by the vertex losses; moved to the GPU and put in eval mode.
    body_model = smplx.create(
        self.args.data_path_1 + "smplx_models/",
        model_type='smplx',
        gender='NEUTRAL_2020',
        use_face_contour=False,
        num_betas=300,
        num_expression_coeffs=100,
        ext='npz',
        use_pca=False,
    )
    self.smplx = body_model.cuda().eval()

    # One meter per loss term, each paired with a False flag.
    meter_names = ["rec", "vel", "ver", "com", "kl", "acc"]
    self.tracker = other_tools.EpochTracker(meter_names, [False] * len(meter_names))

    # Only logged: construction continues even when rot6d is disabled.
    if not self.args.rot6d:
        logger.error(f"this script is for rot6d, your pose rep. is {args.pose_rep}")

    self.rec_loss = get_loss_func("GeodesicLoss")
    self.vel_loss = torch.nn.L1Loss(reduction='mean')
    self.vectices_loss = torch.nn.MSELoss(reduction='mean')
def inverse_selection(self, filtered_t, selection_array, n):
    """Scatter masked columns back into a zero-filled full-width array.

    Args:
        filtered_t: Array-like whose first ``n`` rows hold the values of the
            selected columns, in ascending column order. A 1-D input is
            treated as one value per row (broadcast over the selected
            columns).
        selection_array: 1-D array-like mask; entries equal to 1 mark the
            output columns that receive values.
        n: Number of rows of the output.

    Returns:
        ``np.ndarray`` of shape ``(n, selection_array.size)`` that is zero
        everywhere except in the selected columns.

    Raises:
        IndexError: If ``filtered_t`` has fewer than ``n`` rows.
        ValueError: If a row of ``filtered_t`` cannot be broadcast to the
            number of selected columns (raised by numpy).
    """
    mask = np.asarray(selection_array)
    values = np.asarray(filtered_t)
    if values.ndim == 1:
        # Keep the per-row scalar semantics of the former row-by-row loop.
        values = values[:, None]
    if values.shape[0] < n:
        raise IndexError(
            f"filtered_t has {values.shape[0]} rows but n={n} rows were requested"
        )

    restored = np.zeros((n, mask.size))
    selected_indices = np.where(mask == 1)[0]
    # One vectorised assignment instead of a Python loop over the rows.
    restored[:, selected_indices] = values[:n]
    return restored
def inverse_selection_tensor(self, filtered_t, selection_array, n):
    """Scatter masked columns back into a zero-filled full-width tensor.

    Torch counterpart of ``inverse_selection``. The output width follows the
    mask length (165 for the masks this trainer passes, which the previous
    implementation hard-coded) and the output lives on the device of
    ``filtered_t`` instead of always being moved to CUDA.

    Args:
        filtered_t: Tensor whose first ``n`` rows hold the values of the
            selected columns, in ascending column order. A 1-D tensor is
            treated as one value per row.
        selection_array: 1-D mask (numpy array, list or tensor); entries
            equal to 1 mark the output columns that receive values.
        n: Number of rows of the output.

    Returns:
        Tensor of shape ``(n, len(mask))`` in the default dtype, zero
        everywhere except in the selected columns. Gradients flow back to
        ``filtered_t`` through the scattered entries.

    Raises:
        IndexError: If ``filtered_t`` has fewer than ``n`` rows.
    """
    device = filtered_t.device
    mask = torch.as_tensor(selection_array, device=device)
    selected_indices = torch.where(mask == 1)[0]

    values = filtered_t.unsqueeze(1) if filtered_t.dim() == 1 else filtered_t
    if values.shape[0] < n:
        raise IndexError(
            f"filtered_t has {values.shape[0]} rows but n={n} rows were requested"
        )

    restored = torch.zeros((n, mask.numel()), device=device)
    # One vectorised assignment instead of a Python loop over the rows.
    restored[:, selected_indices] = values[:n]
    return restored
def train(self, epoch):
    """Run one training epoch of the pose autoencoder.

    For every batch the axis-angle target pose is converted to rot6d,
    reconstructed by ``self.model`` and scored with:

    * a geodesic rotation loss ("rec"),
    * L1 losses on first/second temporal differences of the rotation
      matrices ("vel", "acc"),
    * when ``args.rec_ver_weight > 0``: an MSE loss on SMPL-X vertices
      ("ver") plus L1 losses on the vertices' first/second temporal
      differences,
    * when ``args.g_name`` contains "VQVAE": the model's embedding loss
      ("com").

    The "kl" meter is never updated by this method.

    Args:
        epoch: Current epoch index; used for logging and passed to the
            scheduler step after the last batch.
    """
    # NOTE(review): block nesting was reconstructed from a copy of this file
    # whose indentation had been lost; the "VQVAE" branch is assumed to sit at
    # loop level, not inside the vertex-loss branch. Confirm against the repo.

    def smplx_vertices(full_pose, eye_pose, betas, transl, expression):
        """Run SMPL-X on a (frames, 165) axis-angle pose and return its vertices."""
        return self.smplx(
            betas=betas,
            transl=transl,
            expression=expression,
            jaw_pose=full_pose[:, 66:69],
            global_orient=full_pose[:, :3],
            body_pose=full_pose[:, 3:21*3+3],
            left_hand_pose=full_pose[:, 25*3:40*3],
            right_hand_pose=full_pose[:, 40*3:55*3],
            return_verts=True,
            return_joints=True,
            leye_pose=eye_pose[:, 69:72],
            reye_pose=eye_pose[:, 72:75],
        )['vertices']

    self.model.train()
    t_start = time.time()
    self.tracker.reset()
    rec_weight = self.args.rec_weight
    ver_weight = self.args.rec_ver_weight

    for its, dict_data in enumerate(self.train_loader):
        tar_pose = dict_data["pose"].cuda()
        tar_beta = dict_data["beta"].cuda()
        tar_trans = dict_data["trans"].cuda()
        bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints

        # axis-angle -> rotation matrix -> rot6d, the representation fed to the model
        tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3))
        tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
        t_data = time.time() - t_start

        self.opt.zero_grad()
        g_loss_final = 0
        net_out = self.model(tar_pose)

        # Losses are computed on rotation matrices.
        rec_pose = rc.rotation_6d_to_matrix(net_out["rec_pose"].reshape(bs, n, j, 6))
        tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6))

        loss_rec = self.rec_loss(rec_pose, tar_pose) * rec_weight * self.args.rec_pos_weight
        self.tracker.update_meter("rec", "train", loss_rec.item())
        g_loss_final += loss_rec

        # First / second finite differences along the frame axis (dim 1).
        velocity_loss = self.vel_loss(
            rec_pose[:, 1:] - rec_pose[:, :-1],
            tar_pose[:, 1:] - tar_pose[:, :-1],
        ) * rec_weight
        acceleration_loss = self.vel_loss(
            rec_pose[:, 2:] + rec_pose[:, :-2] - 2 * rec_pose[:, 1:-1],
            tar_pose[:, 2:] + tar_pose[:, :-2] - 2 * tar_pose[:, 1:-1],
        ) * rec_weight
        self.tracker.update_meter("vel", "train", velocity_loss.item())
        self.tracker.update_meter("acc", "train", acceleration_loss.item())
        g_loss_final += velocity_loss
        g_loss_final += acceleration_loss

        # vertices loss
        if ver_weight > 0:
            tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3)
            rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3)
            # Re-insert the joints dropped by the joint mask (as zeros).
            rec_pose = self.inverse_selection_tensor(rec_pose, self.train_data.joint_mask, rec_pose.shape[0])
            tar_pose = self.inverse_selection_tensor(tar_pose, self.train_data.joint_mask, tar_pose.shape[0])

            betas = tar_beta.reshape(bs*n, 300)
            transl = tar_trans.reshape(bs*n, 3)
            # Expressions are zeroed for both meshes.
            expression = torch.zeros((bs*n, 100), device=tar_pose.device)
            # Eye poses always come from the target, for both meshes.
            verts_rec = smplx_vertices(rec_pose, tar_pose, betas, transl, expression)
            verts_tar = smplx_vertices(tar_pose, tar_pose, betas, transl, expression)

            vectices_loss = self.vectices_loss(verts_rec, verts_tar)
            self.tracker.update_meter("ver", "train", vectices_loss.item()*rec_weight*ver_weight)
            g_loss_final += vectices_loss*rec_weight*ver_weight

            # The SMPL-X batch is the flattened (bs*n) frame axis, so restore
            # (bs, n, V, 3) before differencing: slicing dim 1 of the flat
            # tensor would difference neighbouring vertices, not frames.
            seq_rec = verts_rec.reshape(bs, n, -1, 3)
            seq_tar = verts_tar.reshape(bs, n, -1, 3)
            vertices_vel_loss = self.vel_loss(
                seq_rec[:, 1:] - seq_rec[:, :-1],
                seq_tar[:, 1:] - seq_tar[:, :-1],
            )
            vertices_acc_loss = self.vel_loss(
                seq_rec[:, 2:] + seq_rec[:, :-2] - 2 * seq_rec[:, 1:-1],
                seq_tar[:, 2:] + seq_tar[:, :-2] - 2 * seq_tar[:, 1:-1],
            )
            # rec_weight is applied exactly once, like for the vertex MSE term.
            g_loss_final += vertices_vel_loss * rec_weight * ver_weight
            g_loss_final += vertices_acc_loss * rec_weight * ver_weight

        # ---------------------- vae -------------------------- #
        if "VQVAE" in self.args.g_name:
            loss_embedding = net_out["embedding_loss"]
            g_loss_final += loss_embedding
            self.tracker.update_meter("com", "train", loss_embedding.item())

        g_loss_final.backward()
        if self.args.grad_norm != 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_norm)
        self.opt.step()

        t_train = time.time() - t_start - t_data
        t_start = time.time()
        # memory_cached() is deprecated; memory_reserved() reports the same quantity.
        mem_cost = torch.cuda.memory_reserved() / 1E9
        lr_g = self.opt.param_groups[0]['lr']
        if its % self.args.log_period == 0:
            self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g)
        # Debug runs stop after the second batch.
        if self.args.debug and its == 1:
            break
    self.opt_s.step(epoch)
def val(self, epoch):
    """Evaluate reconstruction quality on the validation loader.

    Under ``torch.no_grad()`` each batch is converted to rot6d, reconstructed
    by ``self.model`` and scored with the geodesic rotation loss ("rec"),
    the SMPL-X vertex MSE ("ver", only when ``args.rec_ver_weight > 0``) and
    the model's embedding loss ("com", only when ``args.g_name`` contains
    "VQVAE"). Values are only written to the tracker's "val" meters; nothing
    is back-propagated. ``self.val_recording(epoch)`` is called once after
    the last batch.

    Args:
        epoch: Current epoch index, forwarded to ``val_recording``.
    """
    # NOTE(review): block nesting was reconstructed from a copy of this file
    # whose indentation had been lost; the "VQVAE" branch is assumed to sit at
    # loop level, not inside the vertex-loss branch. Confirm against the repo.

    def smplx_vertices(full_pose, eye_pose, betas, transl, expression):
        """Run SMPL-X on a (frames, 165) axis-angle pose and return its vertices."""
        return self.smplx(
            betas=betas,
            transl=transl,
            expression=expression,
            jaw_pose=full_pose[:, 66:69],
            global_orient=full_pose[:, :3],
            body_pose=full_pose[:, 3:21*3+3],
            left_hand_pose=full_pose[:, 25*3:40*3],
            right_hand_pose=full_pose[:, 40*3:55*3],
            return_verts=True,
            leye_pose=eye_pose[:, 69:72],
            reye_pose=eye_pose[:, 72:75],
        )['vertices']

    self.model.eval()
    with torch.no_grad():
        for dict_data in self.val_loader:
            tar_pose = dict_data["pose"].cuda()
            tar_beta = dict_data["beta"].cuda()
            tar_trans = dict_data["trans"].cuda()
            bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints

            # axis-angle -> rotation matrix -> rot6d, the representation fed to the model
            tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3))
            tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)

            net_out = self.model(tar_pose)
            rec_pose = rc.rotation_6d_to_matrix(net_out["rec_pose"].reshape(bs, n, j, 6))
            tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6))

            loss_rec = self.rec_loss(rec_pose, tar_pose) * self.args.rec_weight * self.args.rec_pos_weight
            self.tracker.update_meter("rec", "val", loss_rec.item())

            # vertices loss
            if self.args.rec_ver_weight > 0:
                tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3)
                rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3)
                # The joint mask is read from the training dataset here as well.
                rec_pose = self.inverse_selection_tensor(rec_pose, self.train_data.joint_mask, rec_pose.shape[0])
                tar_pose = self.inverse_selection_tensor(tar_pose, self.train_data.joint_mask, tar_pose.shape[0])

                betas = tar_beta.reshape(bs*n, 300)
                transl = tar_trans.reshape(bs*n, 3)
                # Expressions are zeroed for both meshes.
                expression = torch.zeros((bs*n, 100), device=tar_pose.device)
                # Eye poses always come from the target, for both meshes.
                verts_rec = smplx_vertices(rec_pose, tar_pose, betas, transl, expression)
                verts_tar = smplx_vertices(tar_pose, tar_pose, betas, transl, expression)

                vectices_loss = self.vectices_loss(verts_rec, verts_tar)
                self.tracker.update_meter("ver", "val", vectices_loss.item()*self.args.rec_weight * self.args.rec_ver_weight)

            if "VQVAE" in self.args.g_name:
                loss_embedding = net_out["embedding_loss"]
                self.tracker.update_meter("com", "val", loss_embedding.item())
        self.val_recording(epoch)
def test(self, epoch):
    """Reconstruct every test sequence and save GT / result pairs to disk.

    Results go to ``<checkpoint_path>/<epoch>/``. For each sequence the pose
    is trimmed to a multiple of ``args.pose_length`` frames, reconstructed
    by ``self.model`` and written as ``gt_<id>`` / ``res_<id>`` files:

    * ``.npz`` when ``args.pose_rep`` contains "smplx" (betas copied from
      the source file, expressions and translation zeroed),
    * ``.bvh`` otherwise (header copied from the source BVH, one line of
      XYZ Euler angles in degrees per frame).

    The sequence id is looked up with the loader index, so each batch is
    presumably expected to hold a single sequence.

    Args:
        epoch: Epoch index; used as the name of the output directory.

    Returns:
        ``0`` if the output directory already exists (nothing is recomputed),
        otherwise ``None``.
    """
    # NOTE(review): block nesting was reconstructed from a copy of this file
    # whose indentation had been lost; the timing log is assumed to run once
    # after the loop. Confirm against the repo.
    results_save_path = self.checkpoint_path + f"/{epoch}/"
    if os.path.exists(results_save_path):
        return 0
    os.makedirs(results_save_path)

    start_time = time.time()
    total_length = 0
    test_seq_list = self.test_data.selected_file
    # Number of leading lines copied verbatim from the source BVH file.
    bvh_header_lines = 431

    def format_frame(frame):
        """Format one frame as space-separated values without the brackets."""
        text = np.array2string(frame, max_line_width=np.inf, precision=6, suppress_small=False, separator=' ')
        # Strip exactly '[' and ']'; the former [1:-2] slice could also cut the
        # final digit of the last channel.
        return text[1:-1]

    self.model.eval()
    with torch.no_grad():
        for its, dict_data in enumerate(self.test_loader):
            tar_pose = dict_data["pose"].cuda()
            bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
            tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3))
            tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
            # Drop trailing frames that do not fill a whole pose_length window.
            remain = n % self.args.pose_length
            tar_pose = tar_pose[:, :n-remain, :]

            net_out = self.model(tar_pose)
            rec_pose = net_out["rec_pose"]
            # The model output defines the frame count used from here on.
            n = rec_pose.shape[1]
            tar_pose = tar_pose[:, :n, :]
            rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs, n, j, 6))
            rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3)
            rec_pose = rec_pose.cpu().numpy()

            tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6))
            tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3)
            tar_pose = tar_pose.cpu().numpy()

            total_length += n
            seq_id = test_seq_list.iloc[its]['id']

            # --- save --- #
            if 'smplx' in self.args.pose_rep:
                stride = int(30 / self.args.pose_fps)
                # Close the archive as soon as the needed arrays are read.
                with np.load(self.args.data_path+self.args.pose_rep+"/"+seq_id+'.npz', allow_pickle=True) as gt_npz:
                    betas = gt_npz["betas"]
                    expressions = gt_npz["expressions"]
                    trans = gt_npz["trans"][::stride][:n]
                # Expressions and translation are saved as zeros with the source shapes.
                zero_expressions = np.zeros_like(expressions)
                zero_trans = np.zeros_like(trans)

                tar_pose = self.inverse_selection(tar_pose, self.test_data.joint_mask, tar_pose.shape[0])
                np.savez(
                    results_save_path+"gt_"+seq_id+'.npz',
                    betas=betas,
                    poses=tar_pose[:n],
                    expressions=zero_expressions,
                    trans=zero_trans,
                    model='smplx2020',
                    gender='neutral',
                    mocap_frame_rate=30,
                )
                rec_pose = self.inverse_selection(rec_pose, self.test_data.joint_mask, rec_pose.shape[0])
                np.savez(
                    results_save_path+"res_"+seq_id+'.npz',
                    betas=betas,
                    poses=rec_pose,
                    expressions=zero_expressions,
                    trans=zero_trans,
                    model='smplx2020',
                    gender='neutral',
                    mocap_frame_rate=30,
                )
            else:
                # axis-angle -> XYZ Euler angles in degrees, one row per frame
                rec_pose = rc.axis_angle_to_matrix(torch.from_numpy(rec_pose.reshape(bs*n, j, 3)))
                rec_pose = torch.rad2deg(rc.matrix_to_euler_angles(rec_pose, "XYZ")).reshape(bs*n, j*3).numpy()
                tar_pose = rc.axis_angle_to_matrix(torch.from_numpy(tar_pose.reshape(bs*n, j, 3)))
                tar_pose = torch.rad2deg(rc.matrix_to_euler_angles(tar_pose, "XYZ")).reshape(bs*n, j*3).numpy()

                with open(f"{self.args.data_path}{self.args.pose_rep}/{seq_id}.bvh", "r") as f_demo, \
                        open(results_save_path+"gt_"+seq_id+'.bvh', 'w+') as f_gt, \
                        open(results_save_path+"res_"+seq_id+'.bvh', 'w+') as f_real:
                    # Copy the header lazily instead of reading the whole file.
                    for i, line_data in enumerate(f_demo):
                        if i >= bvh_header_lines:
                            break
                        f_real.write(line_data)
                        f_gt.write(line_data)
                    for line_id in range(n):
                        f_real.write(format_frame(rec_pose[line_id])+'\n')
                    for line_id in range(n):
                        f_gt.write(format_frame(tar_pose[line_id])+'\n')

    end_time = time.time() - start_time
    logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion")