"""Train a masked text-to-motion transformer on top of a pretrained motion VQ-VAE.

The VQ-VAE (loaded from a pretrained checkpoint) tokenizes motions into discrete
codebook indices; the transformer is trained to reconstruct randomly masked tokens
conditioned on CLIP text features.
"""
import os
import json
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Categorical
from os.path import join as pjoin
from einops import rearrange, repeat
from tqdm import tqdm
import clip

import options.option_transformer as option_trans
import models.vqvae as vqvae
import models.t2m_trans as trans
import utils.utils_model as utils_model
import utils.eval_trans as eval_trans
from dataset import dataset_TM_train, dataset_TM_eval, dataset_tokenize
from options.get_eval_option import get_opt
from models.evaluator_wrapper import EvaluatorModelWrapper
from exit.utils import (get_model, visualize_2motions, generate_src_mask,
                        init_save_folder, uniform, cosine_schedule, base_dir)

warnings.filterwarnings('ignore')
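
# Parse CLI options, fix the random seed, and set up the experiment folders, logger, and TensorBoard writer.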
args = option_trans.get_args_parser()
torch.manual_seed(args.seed)

init_save_folder(args)

args.vq_dir = f'./output/vq/{args.vq_name}'
codebook_dir = f'{args.vq_dir}/codebook/'
args.resume_pth = f'{args.vq_dir}/net_last.pth'
os.makedirs(args.vq_dir, exist_ok=True)
os.makedirs(codebook_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
os.makedirs(args.out_dir + '/html', exist_ok=True)

logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
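
# Evaluation pipeline: GloVe word vectorizer, text-motion validation loader, and the
# pretrained text/motion evaluator used for FID, diversity, R-precision, and matching score.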
from utils.word_vectorizer import WordVectorizer
w_vectorizer = WordVectorizer('./glove', 'our_vab')
val_loader = dataset_TM_eval.DATALoader(args.dataname, False, 32, w_vectorizer)

dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' if args.dataname == 'kit' else 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'

wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
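
# Load CLIP ViT-B/32, cast its weights to fp16 (convert_weights), and freeze it.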
clip_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cuda'), jit=False)
clip.model.convert_weights(clip_model)
clip_model.eval()
for p in clip_model.parameters():
    p.requires_grad = False
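
# Wrapper that returns both the pooled CLIP sentence feature (encode_text) and the
# per-token word embeddings from the final transformer layer.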
class TextCLIP(torch.nn.Module):
    def __init__(self, model):
        super(TextCLIP, self).__init__()
        self.model = model

    def forward(self, text):
        with torch.no_grad():
            word_emb = self.model.token_embedding(text).type(self.model.dtype)
            word_emb = word_emb + self.model.positional_embedding.type(self.model.dtype)
            word_emb = word_emb.permute(1, 0, 2)
            word_emb = self.model.transformer(word_emb)
            word_emb = self.model.ln_final(word_emb).permute(1, 0, 2).float()
            enctxt = self.model.encode_text(text).float()
        return enctxt, word_emb

clip_model = TextCLIP(clip_model)
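
# Motion VQ-VAE (tokenizes motion into discrete codebook indices) and the masked
# text-to-motion transformer built on top of it.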
net = vqvae.HumanVQVAE(args,
                       args.nb_code,
                       args.code_dim,
                       args.output_emb_width,
                       args.down_t,
                       args.stride_t,
                       args.width,
                       args.depth,
                       args.dilation_growth_rate)

trans_encoder = trans.Text2Motion_Transformer(vqvae=net,
                                              num_vq=args.nb_code,
                                              embed_dim=args.embed_dim_gpt,
                                              clip_dim=args.clip_dim,
                                              block_size=args.block_size,
                                              num_layers=args.num_layers,
                                              num_local_layer=args.num_local_layer,
                                              n_head=args.n_head_gpt,
                                              drop_out_rate=args.drop_out_rate,
                                              fc_rate=args.ff_rate)
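
# Restore the pretrained VQ-VAE weights and, if given, a transformer checkpoint to resume from.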
print('loading checkpoint from {}'.format(args.resume_pth))
ckpt = torch.load(args.resume_pth, map_location='cpu')
net.load_state_dict(ckpt['net'], strict=True)
net.eval()
net.cuda()

if args.resume_trans is not None:
    print('loading transformer checkpoint from {}'.format(args.resume_trans))
    ckpt = torch.load(args.resume_trans, map_location='cpu')
    trans_encoder.load_state_dict(ckpt['trans'], strict=True)
trans_encoder.train()
trans_encoder.cuda()
trans_encoder = torch.nn.DataParallel(trans_encoder)
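
# Optimizer and multi-step LR schedule for the transformer; loss_ce is defined here,
# but the training loop below computes its loss with F.cross_entropy directly.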
optimizer = utils_model.initial_optim(args.decay_option, args.lr, args.weight_decay, trans_encoder, args.optimizer)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=args.gamma)

loss_ce = torch.nn.CrossEntropyLoss(reduction='none')
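
# One-time tokenization: if no cached codebook indices exist yet, encode every training
# motion with the pretrained VQ-VAE and save its token sequence to codebook_dir.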
if len(os.listdir(codebook_dir)) == 0:
    train_loader_token = dataset_tokenize.DATALoader(args.dataname, 1, unit_length=2**args.down_t)
    for batch in train_loader_token:
        pose, name = batch
        bs, seq = pose.shape[0], pose.shape[1]

        pose = pose.cuda().float()
        target = net(pose, type='encode')
        target = target.cpu().numpy()
        np.save(pjoin(codebook_dir, name[0] + '.npy'), target)
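
# Text-to-motion-token training data built from the cached codebook indices.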
train_loader = dataset_TM_train.DATALoader(args.dataname, args.batch_size, args.nb_code, codebook_dir, unit_length=2**args.down_t)
train_loader_iter = dataset_TM_train.cycle(train_loader)
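
# Best-so-far evaluation metrics, updated by eval_trans.evaluation_transformer.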
best_fid = 1000
best_iter = 0
best_div = 100
best_top1 = 0
best_top2 = 0
best_top3 = 0
best_matching = 100
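
# Prediction accuracy (%) over the positions selected by a boolean mask.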
def get_acc(cls_pred, target, mask):
    cls_pred = torch.masked_select(cls_pred, mask.unsqueeze(-1)).view(-1, cls_pred.shape[-1])
    target_all = torch.masked_select(target, mask)
    probs = torch.softmax(cls_pred, dim=-1)
    _, cls_pred_index = torch.max(probs, dim=-1)
    right_num = (cls_pred_index == target_all).sum()
    return right_num * 100 / mask.sum()
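
# Training loop. Each step: encode the text with frozen CLIP, optionally corrupt the
# ground-truth tokens (pkeep), mask a random 50-100% of the valid tokens with the special
# mask_id token, then train the transformer to recover the original tokens with a
# length-normalized cross-entropy.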
for nb_iter in tqdm(range(1, args.total_iter + 1), position=0, leave=True):
    batch = next(train_loader_iter)
    clip_text, m_tokens, m_tokens_len = batch
    m_tokens, m_tokens_len = m_tokens.cuda(), m_tokens_len.cuda()
    bs = m_tokens.shape[0]
    target = m_tokens.cuda()
    batch_size, max_len = target.shape[:2]
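
    # Encode the prompts with the frozen CLIP model: pooled sentence feature and per-word embeddings.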
    text = clip.tokenize(clip_text, truncate=True).cuda()
    feat_clip_text, word_emb = clip_model(text)
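
    # Token corruption: keep each ground-truth token with probability pkeep (or a random
    # keep rate drawn once per step when pkeep == -1); dropped tokens are replaced by
    # random codebook indices, and padded positions are always kept.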
    if args.pkeep == -1:
        proba = np.random.rand(1)[0]
        mask = torch.bernoulli(proba * torch.ones(target.shape, device=target.device))
    else:
        mask = torch.bernoulli(args.pkeep * torch.ones(target.shape, device=target.device))

    seq_mask_no_end = generate_src_mask(max_len, m_tokens_len)
    mask = torch.logical_or(mask, ~seq_mask_no_end).int()
    r_indices = torch.randint_like(target, args.nb_code)
    input_indices = mask * target + (1 - mask) * r_indices
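
    # Masked-token training: per sequence, draw a mask ratio uniformly from [0.5, 1] and
    # replace that fraction of the valid tokens with mask_id (= num_code + 2). For example,
    # a 20-token sequence with ratio 0.7 gets round(20 * 0.7) = 14 tokens masked.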
    mask_id = get_model(net).vqvae.num_code + 2

    rand_mask_probs = torch.zeros(batch_size, device=m_tokens_len.device).float().uniform_(0.5, 1)
    num_token_masked = (m_tokens_len * rand_mask_probs).round().clamp(min=1)
    seq_mask = generate_src_mask(max_len, m_tokens_len + 1)
    batch_randperm = torch.rand((batch_size, max_len), device=target.device) - seq_mask_no_end.int()
    batch_randperm = batch_randperm.argsort(dim=-1)
    mask_token = batch_randperm < rearrange(num_token_masked, 'b -> b 1')

    masked_input_indices = torch.where(mask_token, mask_id, input_indices)

    att_txt = None
    cls_pred = trans_encoder(masked_input_indices, feat_clip_text, src_mask=seq_mask, att_txt=att_txt, word_emb=word_emb)[:, 1:]
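
    # Cross-entropy over the valid (non-padded) positions, weighted so that every sequence
    # contributes equally to the batch loss regardless of its length.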
    weights = seq_mask_no_end / (seq_mask_no_end.sum(-1).unsqueeze(-1) * seq_mask_no_end.shape[0])
    cls_pred_seq_masked = cls_pred[seq_mask_no_end, :].view(-1, cls_pred.shape[-1])
    target_seq_masked = target[seq_mask_no_end]
    weight_seq_masked = weights[seq_mask_no_end]
    loss_cls = F.cross_entropy(cls_pred_seq_masked, target_seq_masked, reduction='none')
    loss_cls = (loss_cls * weight_seq_masked).sum()

    optimizer.zero_grad()
    loss_cls.backward()
    optimizer.step()
    scheduler.step()
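
    # TensorBoard logging: loss plus token accuracy overall and split into masked vs. unmasked positions.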
    if nb_iter % args.print_iter == 0:
        probs_seq_masked = torch.softmax(cls_pred_seq_masked, dim=-1)
        _, cls_pred_seq_masked_index = torch.max(probs_seq_masked, dim=-1)
        target_seq_masked = torch.masked_select(target, seq_mask_no_end)
        right_seq_masked = (cls_pred_seq_masked_index == target_seq_masked).sum()

        writer.add_scalar('./Loss/all', loss_cls, nb_iter)
        writer.add_scalar('./ACC/every_token', right_seq_masked * 100 / seq_mask_no_end.sum(), nb_iter)

        no_mask_token = ~mask_token * seq_mask_no_end
        writer.add_scalar('./ACC/masked', get_acc(cls_pred, target, mask_token), nb_iter)
        writer.add_scalar('./ACC/no_masked', get_acc(cls_pred, target, no_mask_token), nb_iter)
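
    # Periodic evaluation (FID, diversity, R-precision, matching score); the final iteration
    # re-runs the evaluator with num_repeat=-30 and rand_pos=True.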
    if nb_iter == 0 or nb_iter % args.eval_iter == 0 or nb_iter == args.total_iter:
        num_repeat = 1
        rand_pos = False
        if nb_iter == args.total_iter:
            num_repeat = -30
            rand_pos = True

        val_loader = dataset_TM_eval.DATALoader(args.dataname, True, 32, w_vectorizer)
        pred_pose_eval, pose, m_length, clip_text, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, best_multi, writer, logger = eval_trans.evaluation_transformer(
            args.out_dir, val_loader, net, trans_encoder, logger, writer, nb_iter,
            best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching,
            clip_model=clip_model, eval_wrapper=eval_wrapper, dataname=args.dataname,
            num_repeat=num_repeat, rand_pos=rand_pos)
    if nb_iter == args.total_iter:
        msg_final = f"Train. Iter {best_iter} : FID. {best_fid:.5f}, Diversity. {best_div:.4f}, TOP1. {best_top1:.4f}, TOP2. {best_top2:.4f}, TOP3. {best_top3:.4f}"
        logger.info(msg_final)
        break