# MMM-Demo / generate.py
import torch
import clip
import models.vqvae as vqvae
from models.vqvae_sep import VQVAE_SEP
import models.t2m_trans as trans
import models.t2m_trans_uplow as trans_uplow
import numpy as np
from exit.utils import visualize_2motions
import options.option_transformer as option_trans
##### ---- CLIP ---- #####
clip_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cpu'), jit=False) # Must set jit=False for training
clip.model.convert_weights(clip_model) # convert applicable layers to fp16 (strictly unnecessary: CLIP already loads in float16 by default)
clip_model.eval()
for p in clip_model.parameters():
p.requires_grad = False
# https://github.com/openai/CLIP/issues/111
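# TextCLIP exposes both the pooled sentence embedding (encode_text) and the
# per-token word embeddings, which the mask decoder additionally conditions on.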
class TextCLIP(torch.nn.Module):
    def __init__(self, model):
super(TextCLIP, self).__init__()
self.model = model
def forward(self,text):
with torch.no_grad():
word_emb = self.model.token_embedding(text).type(self.model.dtype)
word_emb = word_emb + self.model.positional_embedding.type(self.model.dtype)
word_emb = word_emb.permute(1, 0, 2) # NLD -> LND
word_emb = self.model.transformer(word_emb)
word_emb = self.model.ln_final(word_emb).permute(1, 0, 2).float()
enctxt = self.model.encode_text(text).float()
return enctxt, word_emb
clip_model = TextCLIP(clip_model)
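# Minimal usage sketch (shapes assume ViT-B/32: context length 77, width 512):
#   tokens = clip.tokenize(['a person walks forward.'], truncate=True)
#   sent_emb, word_emb = clip_model(tokens)  # sent_emb: (1, 512), word_emb: (1, 77, 512)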
def get_vqvae(args, is_upper_edit):
if not is_upper_edit:
return vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
args.nb_code,
args.code_dim,
args.output_emb_width,
args.down_t,
args.stride_t,
args.width,
args.depth,
args.dilation_growth_rate)
else:
return VQVAE_SEP(args, ## use args to define different parameters in different quantizers
args.nb_code,
args.code_dim,
args.output_emb_width,
args.down_t,
args.stride_t,
args.width,
args.depth,
args.dilation_growth_rate,
moment={'mean': torch.from_numpy(args.mean).float(),
'std': torch.from_numpy(args.std).float()},
sep_decoder=True)
def get_maskdecoder(args, vqvae, is_upper_edit):
    transformer = trans if not is_upper_edit else trans_uplow
    return transformer.Text2Motion_Transformer(vqvae,
num_vq=args.nb_code,
embed_dim=args.embed_dim_gpt,
clip_dim=args.clip_dim,
block_size=args.block_size,
num_layers=args.num_layers,
num_local_layer=args.num_local_layer,
n_head=args.n_head_gpt,
drop_out_rate=args.drop_out_rate,
fc_rate=args.ff_rate)
class MMM(torch.nn.Module):
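    """Text-to-motion wrapper: a VQ-VAE motion tokenizer plus a masked-token
    transformer ("mask decoder") that samples motion tokens from CLIP text features."""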
def __init__(self, args=None, is_upper_edit=False):
super().__init__()
self.is_upper_edit = is_upper_edit
args.dataname = args.dataset_name = 't2m'
self.vqvae = get_vqvae(args, is_upper_edit)
ckpt = torch.load(args.resume_pth, map_location='cpu')
self.vqvae.load_state_dict(ckpt['net'], strict=True)
if is_upper_edit:
class VQVAE_WRAPPER(torch.nn.Module):
            def __init__(self, vqvae):
super().__init__()
self.vqvae = vqvae
def forward(self, *args, **kwargs):
return self.vqvae(*args, **kwargs)
self.vqvae = VQVAE_WRAPPER(self.vqvae)
        self.vqvae.eval()
self.maskdecoder = get_maskdecoder(args, self.vqvae, is_upper_edit)
ckpt = torch.load(args.resume_trans, map_location='cpu')
self.maskdecoder.load_state_dict(ckpt['trans'], strict=True)
        self.maskdecoder.train()
def forward(self, text, lengths=-1, rand_pos=True):
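        """Sample motion tokens for each text prompt and decode them to poses.
        `lengths` gives the target length in frames for each element of `text`."""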
b = len(text)
feat_clip_text = clip.tokenize(text, truncate=True)
feat_clip_text, word_emb = clip_model(feat_clip_text)
index_motion = self.maskdecoder(feat_clip_text, word_emb, type="sample", m_length=lengths, rand_pos=rand_pos, if_test=False)
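        # the VQ-VAE downsamples time by a factor of 4, so token length = ceil(frames / 4)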
m_token_length = torch.ceil((lengths)/4).int()
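        # HumanML3D ('t2m') format: at most 196 frames of 263-dim pose features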
pred_pose_all = torch.zeros((b, 196, 263))
for k in range(b):
pred_pose = self.vqvae(index_motion[k:k+1, :m_token_length[k]], type='decode')
pred_pose_all[k:k+1, :int(lengths[k].item())] = pred_pose
return pred_pose_all
def inbetween_eval(self, base_pose, m_length, start_f, end_f, inbetween_text):
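        """Motion in-betweening: keep the tokens outside [start_f, end_f), mask the
        tokens inside, and let the mask decoder inpaint them from `inbetween_text`."""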
bs, seq = base_pose.shape[:2]
tokens = -1*torch.ones((bs, 50), dtype=torch.long)
m_token_length = torch.ceil((m_length)/4).int()
start_t = torch.round((start_f)/4).int()
end_t = torch.round((end_f)/4).int()
for k in range(bs):
index_motion = self.vqvae(base_pose[k:k+1, :m_length[k]], type='encode')
tokens[k, :start_t[k]] = index_motion[0][:start_t[k]]
tokens[k, end_t[k]:m_token_length[k]] = index_motion[0][end_t[k]:m_token_length[k]]
text = clip.tokenize(inbetween_text, truncate=True)
feat_clip_text, word_emb_clip = clip_model(text)
mask_id = self.maskdecoder.num_vq + 2
tokens[tokens==-1] = mask_id
inpaint_index = self.maskdecoder(feat_clip_text, word_emb_clip, type="sample", m_length=m_length, token_cond=tokens)
pred_pose_eval = torch.zeros((bs, seq, base_pose.shape[-1]))
for k in range(bs):
pred_pose = self.vqvae(inpaint_index[k:k+1, :m_token_length[k]], type='decode')
pred_pose_eval[k:k+1, :int(m_length[k].item())] = pred_pose
return pred_pose_eval
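    # Usage sketch (hypothetical tensors): keep frames outside [40, 120) of a
    # ground-truth motion and regenerate the middle span from a new prompt.
    #   edited = mmm.inbetween_eval(gt_pose, m_length, torch.tensor([40]),
    #                               torch.tensor([120]), ['a person jumps.'])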
def long_range(self, text, lengths, num_transition_token=2, output='concat', index_motion=None):
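        """Generate one long motion from several prompts: sample tokens per prompt,
        then inpaint `num_transition_token` tokens between consecutive segments.
        `output='concat'` returns the stitched sequence; `output='eval'` returns
        per-segment poses cut back to the requested lengths."""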
b = len(text)
feat_clip_text = clip.tokenize(text, truncate=True)
feat_clip_text, word_emb = clip_model(feat_clip_text)
if index_motion is None:
index_motion = self.maskdecoder(feat_clip_text, word_emb, type="sample", m_length=lengths, rand_pos=False)
m_token_length = torch.ceil((lengths)/4).int()
if output == 'eval':
frame_length = m_token_length * 4
m_token_length = m_token_length.clone()
m_token_length = m_token_length - 2*num_transition_token
            m_token_length[[0,-1]] += num_transition_token # the first and last segments have a transition on one side only
half_token_length = (m_token_length/2).int()
idx_full_len = half_token_length >= 24
half_token_length[idx_full_len] = half_token_length[idx_full_len] - 1
mask_id = self.maskdecoder.num_vq + 2
tokens = -1*torch.ones((b-1, 50), dtype=torch.long)
transition_train_length = []
for i in range(b-1):
if output == 'concat':
i_index_motion = index_motion[i]
i1_index_motion = index_motion[i+1]
if output == 'eval':
if i == 0:
i_index_motion = index_motion[i, :m_token_length[i]]
else:
i_index_motion = index_motion[i, num_transition_token:m_token_length[i] + num_transition_token]
if i == b-1:
i1_index_motion = index_motion[i+1, :m_token_length[i+1]]
else:
i1_index_motion = index_motion[i+1,
num_transition_token:m_token_length[i+1] + num_transition_token]
left_end = half_token_length[i]
right_start = left_end + num_transition_token
end = right_start + half_token_length[i+1]
tokens[i, :left_end] = i_index_motion[m_token_length[i]-left_end: m_token_length[i]]
tokens[i, left_end:right_start] = mask_id
tokens[i, right_start:end] = i1_index_motion[:half_token_length[i+1]]
transition_train_length.append(end)
transition_train_length = torch.tensor(transition_train_length).to(index_motion.device)
text = clip.tokenize(text[:-1], truncate=True)
feat_clip_text, word_emb_clip = clip_model(text)
inpaint_index = self.maskdecoder(feat_clip_text, word_emb_clip, type="sample", m_length=transition_train_length*4, token_cond=tokens, max_steps=1)
if output == 'concat':
all_tokens = []
for i in range(b-1):
all_tokens.append(index_motion[i, :m_token_length[i]])
all_tokens.append(inpaint_index[i, tokens[i] == mask_id])
all_tokens.append(index_motion[-1, :m_token_length[-1]])
all_tokens = torch.cat(all_tokens).unsqueeze(0)
pred_pose = self.vqvae(all_tokens, type='decode')
return pred_pose
elif output == 'eval':
all_tokens = []
for i in range(b):
motion_token = index_motion[i, :m_token_length[i]]
if i == 0:
first_current_trans_tok = inpaint_index[i, tokens[i] == mask_id]
all_tokens.append(motion_token)
all_tokens.append(first_current_trans_tok)
else:
if i < b-1:
first_current_trans_tok = inpaint_index[i, tokens[i] == mask_id]
all_tokens.append(motion_token)
all_tokens.append(first_current_trans_tok)
else:
all_tokens.append(motion_token)
all_tokens = torch.cat(all_tokens)
pred_pose_concat = self.vqvae(all_tokens.unsqueeze(0), type='decode')
trans_frame = num_transition_token*4
pred_pose = torch.zeros((b, 196, 263))
current_point = 0
for i in range(b):
if i == 0:
start_f = torch.tensor(0)
end_f = frame_length[i]
else:
start_f = current_point - trans_frame
end_f = start_f + frame_length[i]
current_point = end_f
pred_pose[i, :frame_length[i]] = pred_pose_concat[0, start_f: end_f]
return pred_pose
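    # Usage sketch: stitch three prompts into one continuous motion (lengths in frames).
    #   texts = ['a person walks forward.', 'a person turns around.', 'a person sits down.']
    #   long_pose = mmm.long_range(texts, torch.tensor([96, 96, 96]))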
def upper_edit(self, pose, m_length, upper_text, lower_mask=None):
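        """Regenerate the upper body from `upper_text` while conditioning on the
        original lower-body tokens; `lower_mask` optionally masks lower-body tokens
        as well. Requires a model built with is_upper_edit=True (separate
        upper/lower codebooks)."""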
        pose = pose.clone().float() # pose: (bs, seq_len, pose_dim)
m_tokens_len = torch.ceil((m_length)/4)
bs, seq = pose.shape[:2]
max_motion_length = int(seq/4) + 1
mot_end_idx = self.vqvae.vqvae.num_code
mot_pad_idx = self.vqvae.vqvae.num_code + 1
mask_id = self.vqvae.vqvae.num_code + 2
target_lower = []
for k in range(bs):
target = self.vqvae(pose[k:k+1, :m_length[k]], type='encode')
if m_tokens_len[k]+1 < max_motion_length:
target = torch.cat([target,
torch.ones((1, 1, 2), dtype=int, device=target.device) * mot_end_idx,
torch.ones((1, max_motion_length-1-m_tokens_len[k].int().item(), 2), dtype=int, device=target.device) * mot_pad_idx], axis=1)
else:
target = torch.cat([target,
torch.ones((1, 1, 2), dtype=int, device=target.device) * mot_end_idx], axis=1)
target_lower.append(target[..., 1])
target_lower = torch.cat(target_lower, axis=0)
### lower mask ###
if lower_mask is not None:
lower_mask = torch.cat([lower_mask, torch.zeros(bs, 1, dtype=int)], dim=1).bool()
target_lower_masked = target_lower.clone()
target_lower_masked[lower_mask] = mask_id
select_end = target_lower == mot_end_idx
target_lower_masked[select_end] = target_lower[select_end]
else:
target_lower_masked = target_lower
##################
pred_len = m_length
pred_tok_len = m_tokens_len
pred_pose_eval = torch.zeros((bs, seq, pose.shape[-1]))
# __upper_text__ = ['A man punches with right hand.'] * 32
text = clip.tokenize(upper_text, truncate=True)
feat_clip_text, word_emb_clip = clip_model(text)
# index_motion = trans_encoder(feat_clip_text, idx_lower=target_lower_masked, word_emb=word_emb_clip, type="sample", m_length=pred_len, rand_pos=True, CFG=-1)
index_motion = self.maskdecoder(feat_clip_text, target_lower_masked, word_emb_clip, type="sample", m_length=pred_len, rand_pos=True)
for i in range(bs):
all_tokens = torch.cat([
index_motion[i:i+1, :int(pred_tok_len[i].item()), None],
target_lower[i:i+1, :int(pred_tok_len[i].item()), None]
], axis=-1)
pred_pose = self.vqvae(all_tokens, type='decode')
pred_pose_eval[i:i+1, :int(pred_len[i].item())] = pred_pose
return pred_pose_eval
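    # Usage sketch (hypothetical tensors; requires MMM(args, is_upper_edit=True)):
    #   edited = mmm_edit.upper_edit(gt_pose, m_length, ['a man punches with his right hand.'])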
if __name__ == '__main__':
args = option_trans.get_args_parser()
# python generate.py --resume-pth '/home/epinyoan/git/MaskText2Motion/T2M-BD/output/vq/2023-07-19-04-17-17_12_VQVAE_20batchResetNRandom_8192_32/net_last.pth' --resume-trans '/home/epinyoan/git/MaskText2Motion/T2M-BD/output/t2m/2023-10-12-10-11-15_HML3D_45_crsAtt1lyr_40breset_WRONG_THIS_20BRESET/net_last.pth' --text 'the person crouches and walks forward.' --length 156
mmm = MMM(args)
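    # generate one motion from the prompt, then denormalize with the dataset
    # mean/std and render it as an interactive HTML animation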
pred_pose = mmm([args.text], torch.tensor([args.length]), rand_pos=False)
std = np.load('./exit/t2m-std.npy')
mean = np.load('./exit/t2m-mean.npy')
file_name = '_'.join(args.text.split(' '))+'_'+str(args.length)
visualize_2motions(pred_pose[0].detach().cpu().numpy(), std, mean, 't2m', args.length, save_path='./output/'+file_name+'.html')