|
import torch |
|
import clip |
|
import models.vqvae as vqvae |
|
from models.vqvae_sep import VQVAE_SEP |
|
import models.t2m_trans as trans |
|
import models.t2m_trans_uplow as trans_uplow |
|
import numpy as np |
|
from exit.utils import visualize_2motions |
|
import options.option_transformer as option_trans |
|
|
|
|
|
|
|
|
|
# Load CLIP ViT-B/32 on the CPU. jit=False returns the regular nn.Module, whose
# submodules TextCLIP accesses directly.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cpu'), jit=False)
# CLIP's helper that converts the applicable weights to half precision.
clip.model.convert_weights(clip_model)
# CLIP is used purely as a frozen feature extractor: inference mode, and no
# parameter tracks gradients (requires_grad_ visits every parameter).
clip_model.eval()
clip_model.requires_grad_(False)
|
|
|
|
|
class TextCLIP(torch.nn.Module):
    """Expose both the pooled text embedding and per-token features of CLIP.

    ``forward(text)`` returns ``(enctxt, word_emb)``:

    * ``enctxt``: the projected embedding at each row's ``text.argmax(-1)``
      position. This is the value CLIP's ``encode_text`` returns.
    * ``word_emb``: the ``ln_final`` output for every token position.

    Both are returned as float32 and computed under ``torch.no_grad()``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, text):
        """Encode token ids ``text`` (as produced by ``clip.tokenize``).

        The text transformer runs once and both outputs are derived from it.
        Calling ``encode_text`` separately would repeat the whole transformer
        pass.

        Returns:
            Tuple ``(enctxt, word_emb)``, both float32.
        """
        model = self.model
        with torch.no_grad():
            x = model.token_embedding(text).type(model.dtype)
            x = x + model.positional_embedding.type(model.dtype)
            # The transformer takes sequence-first input; permute in and out.
            x = model.transformer(x.permute(1, 0, 2)).permute(1, 0, 2)
            x = model.ln_final(x)
            word_emb = x.float()
            # Same pooling + projection as CLIP's encode_text: take the
            # position of the highest token id in each row (CLIP's
            # end-of-text token) and project it.
            rows = torch.arange(x.shape[0], device=x.device)
            pooled = x.type(model.dtype)[rows, text.argmax(dim=-1)]
            enctxt = (pooled @ model.text_projection).float()
        return enctxt, word_emb
|
# Rebind the module-level name to the wrapper: from here on every caller uses
# clip_model(tokens) -> (enctxt, word_emb) as returned by TextCLIP.forward.
clip_model = TextCLIP(clip_model)
|
|
|
def get_vqvae(args, is_upper_edit):
    """Construct the VQ-VAE variant selected by ``is_upper_edit``.

    Both variants receive the same architecture hyper-parameters from
    ``args``.

    * ``is_upper_edit=False``: a ``HumanVQVAE``.
    * ``is_upper_edit=True``: a ``VQVAE_SEP``. It additionally receives
      ``args.mean`` / ``args.std`` (NumPy arrays, converted to float
      tensors) as ``moment``, and ``sep_decoder=True``.
    """
    shared = (
        args,
        args.nb_code,
        args.code_dim,
        args.output_emb_width,
        args.down_t,
        args.stride_t,
        args.width,
        args.depth,
        args.dilation_growth_rate,
    )
    if not is_upper_edit:
        return vqvae.HumanVQVAE(*shared)
    moment = {'mean': torch.from_numpy(args.mean).float(),
              'std': torch.from_numpy(args.std).float()}
    return VQVAE_SEP(*shared, moment=moment, sep_decoder=True)
|
|
|
def get_maskdecoder(args, vqvae, is_upper_edit):
    """Build the ``Text2Motion_Transformer`` wrapped around ``vqvae``.

    The class comes from ``trans_uplow`` when ``is_upper_edit`` is true and
    from ``trans`` otherwise. All hyper-parameters are read from ``args``.
    """
    module = trans_uplow if is_upper_edit else trans
    hparams = dict(
        num_vq=args.nb_code,
        embed_dim=args.embed_dim_gpt,
        clip_dim=args.clip_dim,
        block_size=args.block_size,
        num_layers=args.num_layers,
        num_local_layer=args.num_local_layer,
        n_head=args.n_head_gpt,
        drop_out_rate=args.drop_out_rate,
        fc_rate=args.ff_rate,
    )
    return module.Text2Motion_Transformer(vqvae, **hparams)
|
|
|
class MMM(torch.nn.Module):
    """Inference wrapper bundling a VQ-VAE tokenizer and a masked token decoder.

    Both sub-models are built from ``args`` and restored from checkpoints:

    * VQ-VAE from ``args.resume_pth`` (key ``'net'``).
    * Mask decoder from ``args.resume_trans`` (key ``'trans'``).

    Prompts are encoded with the module-level ``clip_model``.

    With ``is_upper_edit=True`` the VQ-VAE is the ``VQVAE_SEP`` variant,
    wrapped so it is reachable as ``self.vqvae.vqvae``, and the decoder comes
    from ``trans_uplow``. :meth:`upper_edit` requires that configuration.
    """

    # Frames per VQ token: token counts are computed as ceil(frames / 4).
    _FRAMES_PER_TOKEN = 4
    # Width of the token-condition buffers handed to the mask decoder.
    _MAX_TOKENS = 50
    # (frames, features) of the zero-padded output buffers allocated by
    # ``forward`` and ``long_range``.
    _MAX_FRAMES = 196
    _POSE_DIM = 263

    def __init__(self, args=None, is_upper_edit=False):
        """Build both sub-models and load their checkpoints.

        Args:
            args: Parsed options. Required despite the ``None`` default. It
                is mutated in place: ``dataname`` and ``dataset_name`` are
                forced to ``'t2m'``.
            is_upper_edit: Select the upper/lower-body model variants.

        Raises:
            ValueError: If ``args`` is ``None``.
        """
        super().__init__()
        if args is None:
            raise ValueError("MMM requires parsed options; got args=None")
        self.is_upper_edit = is_upper_edit

        args.dataname = args.dataset_name = 't2m'

        # torch.load unpickles the checkpoint: only load trusted files.
        self.vqvae = get_vqvae(args, is_upper_edit)
        vq_ckpt = torch.load(args.resume_pth, map_location='cpu')
        self.vqvae.load_state_dict(vq_ckpt['net'], strict=True)
        if is_upper_edit:
            class VQVAE_WRAPPER(torch.nn.Module):
                # Thin forwarding shell; the wrapped model stays reachable as
                # ``.vqvae`` (``upper_edit`` reads ``self.vqvae.vqvae``).
                def __init__(self, vqvae):
                    super().__init__()
                    self.vqvae = vqvae

                def forward(self, *args, **kwargs):
                    return self.vqvae(*args, **kwargs)

            self.vqvae = VQVAE_WRAPPER(self.vqvae)
        self.vqvae.eval()

        self.maskdecoder = get_maskdecoder(args, self.vqvae, is_upper_edit)
        trans_ckpt = torch.load(args.resume_trans, map_location='cpu')
        self.maskdecoder.load_state_dict(trans_ckpt['trans'], strict=True)
        # NOTE(review): the decoder is left in train mode, so the dropout it
        # was built with (args.drop_out_rate) is presumably active while
        # sampling. Kept as-is; confirm this is intended.
        self.maskdecoder.train()

    def forward(self, text, lengths=-1, rand_pos=True):
        """Generate one motion per prompt.

        Args:
            text: List of prompt strings; the batch size is ``len(text)``.
            lengths: Per-sample frame counts as a tensor. It is divided by 4
                and passed to ``torch.ceil``, so the ``-1`` default cannot
                be used here.
            rand_pos: Forwarded to the mask decoder's sampler.

        Returns:
            Zero-initialised ``(b, 196, 263)`` tensor. Row ``k`` holds the
            decoded motion in its first ``lengths[k]`` frames.
        """
        b = len(text)
        tokenized = clip.tokenize(text, truncate=True)
        feat_clip_text, word_emb = clip_model(tokenized)
        index_motion = self.maskdecoder(feat_clip_text, word_emb, type="sample", m_length=lengths, rand_pos=rand_pos, if_test=False)

        m_token_length = torch.ceil(lengths / self._FRAMES_PER_TOKEN).int()
        pred_pose_all = torch.zeros((b, self._MAX_FRAMES, self._POSE_DIM))
        for k in range(b):
            pred_pose = self.vqvae(index_motion[k:k+1, :m_token_length[k]], type='decode')
            pred_pose_all[k:k+1, :int(lengths[k].item())] = pred_pose
        return pred_pose_all

    def inbetween_eval(self, base_pose, m_length, start_f, end_f, inbetween_text):
        """Resample the tokens between ``start_f`` and ``end_f`` of each motion.

        Each ``base_pose[k]`` is encoded to tokens, which are then split into
        kept and resampled positions:

        * Kept as conditions: positions before ``start_t`` and positions in
          ``[end_t, m_token_length)``.
        * Masked and resampled from ``inbetween_text``: every other
          position, including slots past the motion's length.

        Args:
            base_pose: Batch of motions indexed ``[batch, frame, ...]``.
            m_length: Per-sample frame counts (tensor).
            start_f, end_f: Per-sample frame bounds (tensors). They are
                mapped to token indices with ``round(f / 4)``.
            inbetween_text: One prompt per sample.

        Returns:
            ``(bs, seq, base_pose.shape[-1])`` tensor. Row ``k`` holds the
            decoded result in its first ``m_length[k]`` frames.
        """
        bs, seq = base_pose.shape[:2]
        tokens = torch.full((bs, self._MAX_TOKENS), -1, dtype=torch.long)
        m_token_length = torch.ceil(m_length / self._FRAMES_PER_TOKEN).int()
        start_t = torch.round(start_f / self._FRAMES_PER_TOKEN).int()
        end_t = torch.round(end_f / self._FRAMES_PER_TOKEN).int()

        for k in range(bs):
            index_motion = self.vqvae(base_pose[k:k+1, :m_length[k]], type='encode')
            tokens[k, :start_t[k]] = index_motion[0][:start_t[k]]
            tokens[k, end_t[k]:m_token_length[k]] = index_motion[0][end_t[k]:m_token_length[k]]

        tokenized = clip.tokenize(inbetween_text, truncate=True)
        feat_clip_text, word_emb_clip = clip_model(tokenized)

        # Every position not filled above is masked for resampling.
        mask_id = self.maskdecoder.num_vq + 2
        tokens[tokens == -1] = mask_id
        inpaint_index = self.maskdecoder(feat_clip_text, word_emb_clip, type="sample", m_length=m_length, token_cond=tokens)

        pred_pose_eval = torch.zeros((bs, seq, base_pose.shape[-1]))
        for k in range(bs):
            pred_pose = self.vqvae(inpaint_index[k:k+1, :m_token_length[k]], type='decode')
            pred_pose_eval[k:k+1, :int(m_length[k].item())] = pred_pose
        return pred_pose_eval

    def long_range(self, text, lengths, num_transition_token=2, output='concat', index_motion=None):
        """Chain several prompts into one motion with generated transitions.

        Each prompt is first sampled independently, unless ``index_motion``
        is supplied. Then, for every adjacent pair ``(i, i+1)``, the mask
        decoder inpaints ``num_transition_token`` masked tokens between the
        tail of motion ``i`` and the head of motion ``i+1``, conditioned on
        ``text[i]``.

        Args:
            text: List of ``b`` prompts.
            lengths: Per-prompt frame counts (tensor).
            num_transition_token: Tokens generated for each transition.
            output: ``'concat'`` or ``'eval'`` (see Returns).
            index_motion: Optional pre-sampled tokens, one row per prompt.

        Returns:
            For ``'concat'``: the single decoded motion (prompt tokens with
            the transitions in between).

            For ``'eval'``: motions are shortened so that transitions replace
            their inner edges, and the whole chain is decoded once. The
            result is a ``(b, 196, 263)`` tensor whose row ``i`` is the
            ``ceil(lengths[i] / 4) * 4``-frame window covering prompt ``i``
            and its adjacent transitions.
        """
        b = len(text)
        tokenized = clip.tokenize(text, truncate=True)
        feat_clip_text, word_emb = clip_model(tokenized)
        if index_motion is None:
            index_motion = self.maskdecoder(feat_clip_text, word_emb, type="sample", m_length=lengths, rand_pos=False)

        m_token_length = torch.ceil(lengths / self._FRAMES_PER_TOKEN).int()
        if output == 'eval':
            frame_length = m_token_length * self._FRAMES_PER_TOKEN
            # Interior motions give up num_transition_token tokens at both
            # ends. The first and last motions give them up only at their
            # single inner end.
            m_token_length = m_token_length - 2 * num_transition_token
            m_token_length[[0, -1]] += num_transition_token

        # Each transition is conditioned on half of each neighbour's tokens.
        # Halves of 24+ tokens lose one token, presumably so that
        # left + transition + right fits in the _MAX_TOKENS buffer (confirm).
        half_token_length = (m_token_length / 2).int()
        idx_full_len = half_token_length >= 24
        half_token_length[idx_full_len] = half_token_length[idx_full_len] - 1

        mask_id = self.maskdecoder.num_vq + 2
        tokens = torch.full((b - 1, self._MAX_TOKENS), -1, dtype=torch.long)
        transition_train_length = []
        for i in range(b - 1):
            if output == 'concat':
                i_index_motion = index_motion[i]
                i1_index_motion = index_motion[i + 1]
            elif output == 'eval':
                # Motion 0 keeps its head; later motions skip their first
                # num_transition_token tokens (taken by the incoming transition).
                if i == 0:
                    i_index_motion = index_motion[i, :m_token_length[i]]
                else:
                    i_index_motion = index_motion[i, num_transition_token:m_token_length[i] + num_transition_token]
                # NOTE(review): the former `if i == b-1` branch here was
                # unreachable inside range(b-1), so motion i+1 is always sliced
                # this way. The final 'eval' concatenation below, however, uses
                # index_motion[i, :m_token_length[i]]. Behaviour kept; confirm
                # which slice is intended.
                i1_index_motion = index_motion[i + 1, num_transition_token:m_token_length[i + 1] + num_transition_token]
            left_end = half_token_length[i]
            right_start = left_end + num_transition_token
            end = right_start + half_token_length[i + 1]

            # Layout per row: [tail of motion i | masked transition | head of motion i+1]
            tokens[i, :left_end] = i_index_motion[m_token_length[i] - left_end: m_token_length[i]]
            tokens[i, left_end:right_start] = mask_id
            tokens[i, right_start:end] = i1_index_motion[:half_token_length[i + 1]]
            transition_train_length.append(end)
        transition_train_length = torch.tensor(transition_train_length).to(index_motion.device)

        transition_text = clip.tokenize(text[:-1], truncate=True)
        feat_clip_text, word_emb_clip = clip_model(transition_text)
        inpaint_index = self.maskdecoder(feat_clip_text, word_emb_clip, type="sample", m_length=transition_train_length * self._FRAMES_PER_TOKEN, token_cond=tokens, max_steps=1)

        if output == 'concat':
            all_tokens = []
            for i in range(b - 1):
                all_tokens.append(index_motion[i, :m_token_length[i]])
                all_tokens.append(inpaint_index[i, tokens[i] == mask_id])
            all_tokens.append(index_motion[-1, :m_token_length[-1]])
            all_tokens = torch.cat(all_tokens).unsqueeze(0)
            return self.vqvae(all_tokens, type='decode')
        elif output == 'eval':
            all_tokens = []
            for i in range(b):
                all_tokens.append(index_motion[i, :m_token_length[i]])
                if i < b - 1:
                    # Transition i sits between motions i and i+1.
                    all_tokens.append(inpaint_index[i, tokens[i] == mask_id])
            pred_pose_concat = self.vqvae(torch.cat(all_tokens).unsqueeze(0), type='decode')
            trans_frame = num_transition_token * self._FRAMES_PER_TOKEN
            return self._split_long_range(pred_pose_concat, frame_length, trans_frame)

    @classmethod
    def _split_long_range(cls, pred_pose_concat, frame_length, trans_frame):
        """Cut a decoded 'eval' chain back into one window per prompt.

        Window ``i`` is ``frame_length[i]`` frames long. The first window
        starts at frame 0. Each later window starts ``trans_frame`` frames
        before the previous window ends, so adjacent windows share the
        transition between them.

        Args:
            pred_pose_concat: Decoded chain indexed ``[0, frame, feature]``.
            frame_length: Per-prompt window lengths in frames.
            trans_frame: Frames per transition.

        Returns:
            Zero-padded ``(len(frame_length), _MAX_FRAMES, _POSE_DIM)`` tensor.
        """
        pred_pose = torch.zeros((len(frame_length), cls._MAX_FRAMES, cls._POSE_DIM))
        window_end = 0
        for i, length in enumerate(frame_length):
            length = int(length)
            start = 0 if i == 0 else window_end - trans_frame
            # Advance after every window, including the first, so the next
            # window starts inside this window's trailing transition.
            window_end = start + length
            pred_pose[i, :length] = pred_pose_concat[0, start:window_end]
        return pred_pose

    def upper_edit(self, pose, m_length, upper_text, lower_mask=None):
        """Resample upper-body tokens from ``upper_text``, keeping the lower body.

        Requires ``is_upper_edit=True``, because it reads
        ``self.vqvae.vqvae.num_code``. Per sample, the steps are:

        1. Encode the motion into 2-channel tokens.
        2. Append an end token, plus pad tokens when there is room.
        3. Keep channel 1 (the lower-body stream, per this method's
           naming) as the condition.
        4. The mask decoder samples channel 0 from ``upper_text``.
        5. Decode the sampled channel 0 together with the original,
           unmasked channel 1.

        Args:
            pose: Batch of motions indexed ``[batch, frame, ...]``.
            m_length: Per-sample frame counts (tensor).
            upper_text: One prompt per sample.
            lower_mask: Optional boolean-like mask over lower-body token
                positions. Masked positions, except end tokens, are replaced
                by the mask id in the decoder's condition.

        Returns:
            ``(bs, seq, pose.shape[-1])`` tensor. Row ``i`` holds the edited
            motion in its first ``m_length[i]`` frames.
        """
        pose = pose.clone().float()
        m_tokens_len = torch.ceil(m_length / self._FRAMES_PER_TOKEN)
        bs, seq = pose.shape[:2]
        max_motion_length = int(seq / self._FRAMES_PER_TOKEN) + 1
        mot_end_idx = self.vqvae.vqvae.num_code
        mot_pad_idx = self.vqvae.vqvae.num_code + 1
        mask_id = self.vqvae.vqvae.num_code + 2

        target_lower = []
        for k in range(bs):
            target = self.vqvae(pose[k:k+1, :m_length[k]], type='encode')
            end_tok = torch.ones((1, 1, 2), dtype=int, device=target.device) * mot_end_idx
            if m_tokens_len[k] + 1 < max_motion_length:
                n_pad = max_motion_length - 1 - m_tokens_len[k].int().item()
                pad_tok = torch.ones((1, n_pad, 2), dtype=int, device=target.device) * mot_pad_idx
                target = torch.cat([target, end_tok, pad_tok], axis=1)
            else:
                target = torch.cat([target, end_tok], axis=1)
            target_lower.append(target[..., 1])
        target_lower = torch.cat(target_lower, axis=0)

        if lower_mask is not None:
            # Widen the mask by one never-masked column (presumably for the
            # trailing end-token slot; confirm the expected lower_mask width).
            lower_mask = torch.cat([lower_mask, torch.zeros(bs, 1, dtype=int)], dim=1).bool()
            target_lower_masked = target_lower.clone()
            target_lower_masked[lower_mask] = mask_id
            # End tokens are always restored, never masked.
            select_end = target_lower == mot_end_idx
            target_lower_masked[select_end] = target_lower[select_end]
        else:
            target_lower_masked = target_lower

        pred_pose_eval = torch.zeros((bs, seq, pose.shape[-1]))

        tokenized = clip.tokenize(upper_text, truncate=True)
        feat_clip_text, word_emb_clip = clip_model(tokenized)
        index_motion = self.maskdecoder(feat_clip_text, target_lower_masked, word_emb_clip, type="sample", m_length=m_length, rand_pos=True)
        for i in range(bs):
            n_tok = int(m_tokens_len[i].item())
            # Decode the sampled channel 0 with the original (unmasked) channel 1.
            all_tokens = torch.cat([
                index_motion[i:i+1, :n_tok, None],
                target_lower[i:i+1, :n_tok, None],
            ], axis=-1)
            pred_pose = self.vqvae(all_tokens, type='decode')
            pred_pose_eval[i:i+1, :int(m_length[i].item())] = pred_pose
        return pred_pose_eval
|
|
|
|
|
if __name__ == '__main__':
    import os

    # Generate one motion for args.text / args.length and save an HTML
    # visualisation under ./output/.
    args = option_trans.get_args_parser()
    mmm = MMM(args)
    # Inference only: no autograd graph is needed for the detached result.
    with torch.no_grad():
        pred_pose = mmm([args.text], torch.tensor([args.length]), rand_pos=False)

    # Statistics handed to the visualiser, presumably to de-normalise the motion.
    std = np.load('./exit/t2m-std.npy')
    mean = np.load('./exit/t2m-mean.npy')
    file_name = '_'.join(args.text.split(' ')) + '_' + str(args.length)
    # Make sure the target directory exists before the HTML file is written.
    os.makedirs('./output', exist_ok=True)
    visualize_2motions(pred_pose[0].detach().cpu().numpy(), std, mean, 't2m', args.length, save_path='./output/' + file_name + '.html')
|
|
|
|
|
|