# MoMask / edit_t2m.py
# MeYourHint: first demo version (c0eac48)
import os
from os.path import join as pjoin
import torch
import torch.nn.functional as F
from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
from models.vq.model import RVQVAE, LengthEstimator
from options.eval_option import EvalT2MOptions
from utils.get_opt import get_opt
from utils.fixseed import fixseed
from visualization.joints2bvh import Joint2BVHConvertor
from utils.motion_process import recover_from_ric
from utils.plot_script import plot_3d_motion
from utils.paramUtil import t2m_kinematic_chain
import numpy as np
from gen_t2m import load_vq_model, load_res_model, load_trans_model
# Values the script uses throughout: token <-> frame conversions use a factor
# of 4, and both the rendered videos and the time label in the title use 20.
FRAMES_PER_TOKEN = 4
FPS = 20


def _parse_edit_bound(text):
    """Parse one edit-section bound: ``int`` if written as an integer, else ``float``.

    Raises:
        ValueError: if ``text`` is not a plain number.
    """
    text = text.strip()
    try:
        return int(text)
    except ValueError:
        return float(text)


def parse_edit_sections(specs):
    """Parse ``"start,end"`` strings into ``[start, end]`` numeric pairs.

    The bounds are parsed with ``int``/``float`` instead of ``eval`` so that
    command-line text is never executed as Python code.

    Args:
        specs: iterable of strings, each of the form ``"start,end"``.

    Returns:
        A list of ``[start, end]`` lists. Each bound is an ``int`` when it was
        written as an integer (e.g. ``"40"``) and a ``float`` otherwise
        (e.g. ``"0.25"``).

    Raises:
        ValueError: if a spec does not contain exactly two comma-separated
            numbers.
    """
    sections = []
    for spec in specs:
        parts = spec.split(',')
        if len(parts) != 2:
            raise ValueError(
                f"Edit section must look like 'start,end', got {spec!r}.")
        sections.append([_parse_edit_bound(parts[0]), _parse_edit_bound(parts[1])])
    return sections


def edit_section_to_token_span(start, end, seq_len):
    """Convert one edit section into a ``(start, end)`` token-index span.

    If either bound is a float, both bounds are treated as fractions of
    ``seq_len``; otherwise both are frame indices and are divided by
    ``FRAMES_PER_TOKEN``. Checking both bounds (rather than only ``start``)
    keeps specs such as ``0,0.5`` from producing a float slice index.

    Args:
        start: section start (fraction or frame index).
        end: section end (fraction or frame index).
        seq_len: number of tokens in the encoded source motion.

    Returns:
        Tuple ``(start_token, end_token)`` of ints.
    """
    if isinstance(start, float) or isinstance(end, float):
        return int(start * seq_len), int(end * seq_len)
    return start // FRAMES_PER_TOKEN, end // FRAMES_PER_TOKEN


def main():
    """Edit masked sections of a source motion so they follow a text prompt.

    Loads the RVQ-VAE, residual transformer and mask transformer checkpoints,
    encodes ``opt.source_motion`` into tokens, regenerates the token ranges
    named by ``opt.mask_edit_section`` and writes BVH / MP4 / NPY results
    under ``./editing/<opt.ext>/``.

    Raises:
        ValueError: if the text prompt is empty, an edit section is malformed,
            or the residual transformer was trained on a different VQ model.
    """
    parser = EvalT2MOptions()
    opt = parser.parse()

    # Fail before any checkpoint is loaded. The original raised a bare string,
    # which Python 3 rejects with a TypeError instead of showing the message.
    if opt.text_prompt == "":
        raise ValueError("Using an empty text prompt.")

    fixseed(opt.seed)

    opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
    # NOTE(review): anomaly detection is a debugging aid and every model call
    # below runs under torch.no_grad(); consider dropping this line.
    torch.autograd.set_detect_anomaly(True)

    dim_pose = 251 if opt.dataset_name == 'kit' else 263

    root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
    result_dir = pjoin('./editing', opt.ext)
    joints_dir = pjoin(result_dir, 'joints')
    animation_dir = pjoin(result_dir, 'animations')
    os.makedirs(joints_dir, exist_ok=True)
    os.makedirs(animation_dir, exist_ok=True)

    model_opt_path = pjoin(root_dir, 'opt.txt')
    model_opt = get_opt(model_opt_path, device=opt.device)

    #######################
    ######Loading RVQ######
    #######################
    vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')
    vq_opt = get_opt(vq_opt_path, device=opt.device)
    vq_opt.dim_pose = dim_pose
    vq_model, vq_opt = load_vq_model(vq_opt)

    # The mask transformer options take their codebook settings from the VQ options.
    model_opt.num_tokens = vq_opt.nb_code
    model_opt.num_quantizers = vq_opt.num_quantizers
    model_opt.code_dim = vq_opt.code_dim

    #################################
    ######Loading R-Transformer######
    #################################
    res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')
    res_opt = get_opt(res_opt_path, device=opt.device)
    res_model = load_res_model(res_opt, vq_opt, opt)

    # Explicit raise instead of assert so the check survives `python -O`.
    if res_opt.vq_name != model_opt.vq_name:
        raise ValueError(
            f"Residual transformer uses VQ model {res_opt.vq_name!r} but the "
            f"mask transformer uses {model_opt.vq_name!r}.")

    #################################
    ######Loading M-Transformer######
    #################################
    t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar')

    t2m_transformer.eval()
    vq_model.eval()
    res_model.eval()

    res_model.to(opt.device)
    t2m_transformer.to(opt.device)
    vq_model.to(opt.device)

    ##### ---- Data ---- #####
    max_motion_length = 196
    meta_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta')
    mean = np.load(pjoin(meta_dir, 'mean.npy'))
    std = np.load(pjoin(meta_dir, 'std.npy'))

    def inv_transform(data):
        """Undo the (x - mean) / std normalisation applied to the source motion."""
        return data * std + mean

    ### We provided an example source motion (from 'new_joint_vecs') for editing. See './example_data/000612.mp4'###
    motion = np.load(opt.source_motion)
    source_len = len(motion)
    motion = (motion - mean) / std
    if max_motion_length > source_len:
        # Pad with zeros up to max_motion_length. Matching the dtype keeps
        # np.zeros' float64 default from upcasting short clips.
        padding = np.zeros((max_motion_length - source_len, motion.shape[1]),
                           dtype=motion.dtype)
        motion = np.concatenate([motion, padding], axis=0)
    motion = torch.from_numpy(motion)[None].to(opt.device)

    if opt.motion_length == 0:
        opt.motion_length = source_len
        print("Using default motion length.")

    captions = [opt.text_prompt]
    # Round the requested length down to a whole number of tokens.
    token_lens = torch.LongTensor([opt.motion_length]) // FRAMES_PER_TOKEN
    token_lens = token_lens.to(opt.device).long()
    m_length = token_lens * FRAMES_PER_TOKEN
    print_captions = captions[0]

    edit_slice = parse_edit_sections(opt.mask_edit_section)

    kinematic_chain = t2m_kinematic_chain
    converter = Joint2BVHConvertor()

    with torch.no_grad():
        tokens, _ = vq_model.encode(motion)

    ### build editing mask, TOEDIT marked as 1 ###
    edit_mask = torch.zeros_like(tokens[..., 0])
    seq_len = tokens.shape[1]
    for start, end in edit_slice:
        start_tok, end_tok = edit_section_to_token_span(start, end, seq_len)
        edit_mask[:, start_tok:end_tok] = 1
        # Append the edited time range to the title used for the videos.
        print_captions = (
            f'{print_captions} [{start_tok*FRAMES_PER_TOKEN/float(FPS)}s - '
            f'{end_tok*FRAMES_PER_TOKEN/float(FPS)}s]')
    edit_mask = edit_mask.bool()

    # The de-normalised source motion is the same for every repeat, so it is
    # computed once here instead of inside the loop.
    source_motions = inv_transform(motion.detach().cpu().numpy())

    for r in range(opt.repeat_times):
        print("-->Repeat %d"%r)
        with torch.no_grad():
            mids = t2m_transformer.edit(
                captions, tokens[..., 0].clone(), m_length//FRAMES_PER_TOKEN,
                timesteps=opt.time_steps,
                cond_scale=opt.cond_scale,
                temperature=opt.temperature,
                topk_filter_thres=opt.topkr,
                gsample=opt.gumbel_sample,
                force_mask=opt.force_mask,
                edit_mask=edit_mask.clone(),
            )
            if opt.use_res_model:
                mids = res_model.generate(mids, captions, m_length//FRAMES_PER_TOKEN,
                                          temperature=1, cond_scale=2)
            else:
                mids.unsqueeze_(-1)

            pred_motions = vq_model.forward_decoder(mids)
            pred_motions = pred_motions.detach().cpu().numpy()
            data = inv_transform(pred_motions)

        for k, (caption, joint_data, source_sample) in enumerate(zip(captions, data, source_motions)):
            length = int(m_length[k])
            print("---->Sample %d: %s %d"%(k, caption, length))
            animation_path = pjoin(animation_dir, str(k))
            joint_path = pjoin(joints_dir, str(k))
            os.makedirs(animation_path, exist_ok=True)
            os.makedirs(joint_path, exist_ok=True)

            # Drop the padded tail before recovering joint positions.
            # NOTE(review): 22 is presumably the skeleton's joint count; confirm
            # against recover_from_ric.
            joint_data = joint_data[:length]
            joint = recover_from_ric(torch.from_numpy(joint_data).float(), 22).numpy()

            stem = "sample%d_repeat%d_len%d" % (k, r, length)

            bvh_path = pjoin(animation_path, stem + "_ik.bvh")
            _, ik_joint = converter.convert(joint, filename=bvh_path, iterations=100)

            bvh_path = pjoin(animation_path, stem + ".bvh")
            _, joint = converter.convert(joint, filename=bvh_path, iterations=100, foot_ik=False)

            save_path = pjoin(animation_path, stem + ".mp4")
            ik_save_path = pjoin(animation_path, stem + "_ik.mp4")
            source_save_path = pjoin(animation_path, "sample%d_source_len%d.mp4"%(k, length))

            plot_3d_motion(ik_save_path, kinematic_chain, ik_joint, title=print_captions, fps=FPS)
            plot_3d_motion(save_path, kinematic_chain, joint, title=print_captions, fps=FPS)
            if r == 0:
                # The source video's path and content do not depend on the
                # repeat index, so render it once.
                source_joint = recover_from_ric(
                    torch.from_numpy(source_sample[:length]).float(), 22).numpy()
                plot_3d_motion(source_save_path, kinematic_chain, source_joint, title='None', fps=FPS)

            np.save(pjoin(joint_path, stem + ".npy"), joint)
            np.save(pjoin(joint_path, stem + "_ik.npy"), ik_joint)


if __name__ == '__main__':
    main()