|
import os |
|
import json |
|
|
|
import torch |
|
import torch.optim as optim |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
import models.vqvae as vqvae |
|
import utils.losses as losses |
|
import options.option_vq as option_vq |
|
import utils.utils_model as utils_model |
|
from dataset import dataset_VQ, dataset_TM_eval |
|
import utils.eval_trans as eval_trans |
|
from options.get_eval_option import get_opt |
|
from models.evaluator_wrapper import EvaluatorModelWrapper |
|
import warnings |
|
warnings.filterwarnings('ignore') |
|
from utils.word_vectorizer import WordVectorizer |
|
from tqdm import tqdm |
|
from exit.utils import get_model, generate_src_mask, init_save_folder |
|
from models.vqvae_sep import VQVAE_SEP |
|
|
|
def update_lr_warm_up(optimizer, nb_iter, warm_up_iter, lr):
    """Write a linearly warmed-up learning rate into every parameter group.

    The rate is ``lr * (nb_iter + 1) / (warm_up_iter + 1)``: it grows
    linearly with ``nb_iter`` and equals ``lr`` when
    ``nb_iter == warm_up_iter``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``"lr"`` entry is overwritten.
        nb_iter: Index of the current warm-up iteration.
        warm_up_iter: Length of the warm-up phase, in iterations.
        lr: Learning rate the ramp is scaled against.

    Returns:
        A ``(optimizer, current_lr)`` tuple: the same optimizer object
        (mutated in place) and the rate that was written.
    """
    current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
    for param_group in optimizer.param_groups:
        param_group["lr"] = current_lr

    return optimizer, current_lr
|
|
|
|
|
# ---- Experiment directory, logger and TensorBoard writer ----
args = option_vq.get_args_parser()
torch.manual_seed(args.seed)

# Everything produced by this script goes into the "vq" sub-folder of the
# requested output directory.
args.out_dir = os.path.join(args.out_dir, 'vq')

# Project helper; presumably prepares args.out_dir on disk -- confirm in
# exit.utils.
init_save_folder(args)

logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)

# Record the complete run configuration at the top of the log.
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
|
|
|
|
|
|
|
# ---- Evaluation resources: word vectors and the evaluator wrapper ----
w_vectorizer = WordVectorizer('./glove', 'our_vab')

# 'kit' uses its own evaluator options and 21 joints; every other dataname
# falls through to the t2m options with 22 joints.
if args.dataname == 'kit':
    dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt'
    args.nb_joints = 21
else:
    dataset_opt_path = 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
    args.nb_joints = 22

logger.info(f'Training on {args.dataname}, motions are with {args.nb_joints} joints')

# The evaluator options are always loaded for a CUDA device.
wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
|
|
|
|
|
|
|
# ---- Data loaders ----
# Both loaders receive unit_length = 2 ** down_t (presumably the temporal
# down-sampling factor of the VQ-VAE -- confirm against the dataset modules).
train_loader = dataset_VQ.DATALoader(
    args.dataname,
    args.batch_size,
    window_size=args.window_size,
    unit_length=2**args.down_t,
)

# Wrapped by dataset_VQ.cycle so the loops below can keep calling next()
# on it for as many iterations as they need.
train_loader_iter = dataset_VQ.cycle(train_loader)

# Validation loader; the positional arguments False and 32 are passed
# through unchanged (their meaning is defined by dataset_TM_eval.DATALoader).
val_loader = dataset_TM_eval.DATALoader(
    args.dataname,
    False,
    32,
    w_vectorizer,
    unit_length=2**args.down_t,
)
|
|
|
|
|
# ---- Network ----
# Positional arguments shared by both VQ-VAE variants, built once so the two
# constructor calls cannot drift apart.
_vq_common_args = (
    args,
    args.nb_code,
    args.code_dim,
    args.output_emb_width,
    args.down_t,
    args.stride_t,
    args.width,
    args.depth,
    args.dilation_growth_rate,
    args.vq_act,
    args.vq_norm,
)

if args.sep_uplow:
    # VQVAE_SEP additionally receives the training-set mean/std as float
    # CUDA tensors, followed by a trailing positional True (its meaning is
    # defined by VQVAE_SEP's constructor).
    _norm_stats = {
        'mean': torch.from_numpy(train_loader.dataset.mean).cuda().float(),
        'std': torch.from_numpy(train_loader.dataset.std).cuda().float(),
    }
    net = VQVAE_SEP(*_vq_common_args, _norm_stats, True)
else:
    net = vqvae.HumanVQVAE(*_vq_common_args)
|
|
|
|
|
# ---- Optional resume from a checkpoint ----
if args.resume_pth:
    logger.info('loading checkpoint from {}'.format(args.resume_pth))
    # NOTE: torch.load unpickles the file -- only resume from checkpoints
    # you trust. Tensors are mapped to CPU first; the model is moved to the
    # GPU below.
    ckpt = torch.load(args.resume_pth, map_location='cpu')
    # strict=True: the checkpoint's 'net' entry must match the model's
    # state dict exactly.
    net.load_state_dict(ckpt['net'], strict=True)

# Always executed, whether or not a checkpoint was loaded.
net.train()
net.cuda()
|
|
|
|
|
# ---- Optimizer, LR scheduler and reconstruction loss ----
optimizer = optim.AdamW(
    net.parameters(),
    lr=args.lr,
    betas=(0.9, 0.99),
    weight_decay=args.weight_decay,
)

# Multiplies the learning rate by args.gamma at each milestone listed in
# args.lr_scheduler; the scheduler is stepped once per training iteration
# in the main loop.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=args.lr_scheduler,
    gamma=args.gamma,
)

Loss = losses.ReConsLoss(args.recons_loss, args.nb_joints)
|
|
|
|
|
# ---- Warm-up phase ----
# Running sums for the metrics reported every args.print_iter iterations.
avg_recons, avg_perplexity, avg_commit = 0., 0., 0.

# Iterations 1 .. warm_up_iter - 1. The LR scheduler is not stepped here;
# the learning rate is set explicitly at each iteration instead.
# NOTE(review): the last rate written is
# args.lr * warm_up_iter / (warm_up_iter + 1) and nothing visible afterwards
# sets it back to args.lr -- confirm this is intended.
for nb_iter in tqdm(range(1, args.warm_up_iter)):

    optimizer, current_lr = update_lr_warm_up(optimizer, nb_iter, args.warm_up_iter, args.lr)

    gt_motion = next(train_loader_iter)
    gt_motion = gt_motion.cuda().float()

    # NOTE(review): the main training loop calls net(gt_motion, idx_noise=0)
    # when args.sep_uplow is set, while warm-up always uses the plain call --
    # confirm the difference is deliberate.
    pred_motion, loss_commit, perplexity = net(gt_motion)
    loss_motion = Loss(pred_motion, gt_motion)
    loss_vel = Loss.forward_joint(pred_motion, gt_motion)

    # Total objective: reconstruction + weighted commitment + weighted
    # joint term from Loss.forward_joint.
    loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    avg_recons += loss_motion.item()
    avg_perplexity += perplexity.item()
    avg_commit += loss_commit.item()

    if nb_iter % args.print_iter == 0:
        # Turn the running sums into per-iteration means, log them, and
        # start a fresh window.
        avg_recons /= args.print_iter
        avg_perplexity /= args.print_iter
        avg_commit /= args.print_iter

        logger.info(f"Warmup. Iter {nb_iter} : lr {current_lr:.5f} \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}")

        avg_recons, avg_perplexity, avg_commit = 0., 0., 0.
|
|
|
|
|
# ---- Main training phase ----
# Discard any partial window left over from warm-up.
avg_recons, avg_perplexity, avg_commit = 0., 0., 0.

# Evaluate once at iteration 0 to initialise the best-so-far metrics.
best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(
    args.out_dir, val_loader, net, logger, writer, 0,
    best_fid=1000, best_iter=0, best_div=100,
    best_top1=0, best_top2=0, best_top3=0, best_matching=100,
    eval_wrapper=eval_wrapper)

for nb_iter in tqdm(range(1, args.total_iter + 1)):

    gt_motion = next(train_loader_iter)
    gt_motion = gt_motion.cuda().float()

    # The separated upper/lower model is called with an explicit
    # idx_noise=0; the plain model takes only the motion batch.
    if args.sep_uplow:
        pred_motion, loss_commit, perplexity = net(gt_motion, idx_noise=0)
    else:
        pred_motion, loss_commit, perplexity = net(gt_motion)
    loss_motion = Loss(pred_motion, gt_motion)
    loss_vel = Loss.forward_joint(pred_motion, gt_motion)

    # Total objective: reconstruction + weighted commitment + weighted
    # joint term from Loss.forward_joint.
    loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Stepped per iteration, so args.lr_scheduler milestones are counted in
    # iterations of this loop.
    scheduler.step()

    avg_recons += loss_motion.item()
    avg_perplexity += perplexity.item()
    avg_commit += loss_commit.item()

    if nb_iter % args.print_iter == 0:
        # Turn the running sums into per-iteration means, report them to
        # TensorBoard and the log, and start a fresh window.
        avg_recons /= args.print_iter
        avg_perplexity /= args.print_iter
        avg_commit /= args.print_iter

        writer.add_scalar('./Train/L1', avg_recons, nb_iter)
        writer.add_scalar('./Train/PPL', avg_perplexity, nb_iter)
        writer.add_scalar('./Train/Commit', avg_commit, nb_iter)

        logger.info(f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}")

        avg_recons, avg_perplexity, avg_commit = 0., 0., 0.

    if nb_iter % args.eval_iter == 0:
        # Periodic evaluation; the best-so-far metrics (and the writer and
        # logger) are threaded through and replaced by the returned values.
        best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(
            args.out_dir, val_loader, net, logger, writer, nb_iter,
            best_fid, best_iter, best_div,
            best_top1, best_top2, best_top3, best_matching,
            eval_wrapper=eval_wrapper)

# Make sure every pending TensorBoard event reaches disk before the script
# ends.
writer.close()
|
|