import torch
from collections import defaultdict
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from collections import OrderedDict

from utils.utils import *
from os.path import join as pjoin
from utils.eval_t2m import evaluation_mask_transformer, evaluation_res_transformer
from models.mask_transformer.tools import *

from einops import rearrange, repeat

# `time` and `np` are used below; import them explicitly (after the star
# imports, so they cannot be shadowed) instead of relying on `utils.utils`.
import time
import numpy as np


def def_value():
    """Default factory for the running-sum training log: every new tag starts at 0.0."""
    return 0.0


def _strip_clip_weights(state_dict):
    """Delete every ``clip_model.*`` entry from ``state_dict`` in place and return it.

    Keys are deleted in place (instead of building a new dict) so any
    ``_metadata`` attribute attached by ``nn.Module.state_dict()`` survives.
    """
    clip_keys = [key for key in state_dict.keys() if key.startswith('clip_model.')]
    for key in clip_keys:
        del state_dict[key]
    return state_dict


def _check_resume_keys(missing_keys, unexpected_keys):
    """Validate the result of a non-strict ``load_state_dict`` when resuming.

    Checkpoints are written without ``clip_model.*`` weights, so those are the
    only keys allowed to be missing; no unexpected keys are allowed.

    Raises:
        RuntimeError: if the checkpoint does not match the model. An explicit
            raise is used (not ``assert``) so the check survives ``python -O``.
    """
    if len(unexpected_keys) != 0:
        raise RuntimeError(f'Unexpected keys in checkpoint: {list(unexpected_keys)}')
    bad_missing = [k for k in missing_keys if not k.startswith('clip_model.')]
    if bad_missing:
        raise RuntimeError(f'Checkpoint is missing non-CLIP keys: {bad_missing}')


def _validation_means(forward_fn, loader):
    """Run ``forward_fn`` on every batch of ``loader`` without gradients.

    Returns:
        ``(mean_loss, mean_acc)`` computed with ``np.mean`` over the batches.
    """
    val_loss = []
    val_acc = []
    with torch.no_grad():
        for batch_data in loader:
            loss, acc = forward_fn(batch_data)
            val_loss.append(loss.item())
            val_acc.append(acc)
    return np.mean(val_loss), np.mean(val_acc)


class MaskTransformerTrainer:
    """Train ``t2m_transformer`` on token indices produced by ``vq_model``.

    ``vq_model`` is kept in eval mode and its parameters are not given to the
    optimizer; only the transformer is trained.
    """

    def __init__(self, args, t2m_transformer, vq_model):
        self.opt = args
        self.t2m_transformer = t2m_transformer
        self.vq_model = vq_model
        self.device = args.device
        self.vq_model.eval()

        if args.is_train:
            self.logger = SummaryWriter(args.log_dir)

    def update_lr_warm_up(self, nb_iter, warm_up_iter, lr):
        """Set every param group's lr to ``lr * (nb_iter + 1) / (warm_up_iter + 1)`` and return it."""
        current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
        for param_group in self.opt_t2m_transformer.param_groups:
            param_group["lr"] = current_lr

        return current_lr

    def forward(self, batch_data):
        """Compute loss and accuracy of the transformer for one batch.

        Args:
            batch_data: ``(conds, motion, m_lens)``. ``conds`` is moved to the
                device only when it is a tensor (otherwise passed through as-is).

        Returns:
            ``(loss, acc)`` as produced by ``self.t2m_transformer``.
        """
        conds, motion, m_lens = batch_data
        motion = motion.detach().float().to(self.device)
        m_lens = m_lens.detach().long().to(self.device)

        # code_idx presumably (b, n, q): token ids per quantizer layer -- TODO confirm.
        code_idx, _ = self.vq_model.encode(motion)
        # NOTE(review): 4 is assumed to be the VQ model's temporal downsampling factor.
        m_lens = m_lens // 4

        conds = conds.to(self.device).float() if torch.is_tensor(conds) else conds

        # Only the first quantizer layer's tokens are fed to this transformer.
        loss, _pred_ids, acc = self.t2m_transformer(code_idx[..., 0], conds, m_lens)
        return loss, acc

    def update(self, batch_data):
        """Run one optimization step (including one scheduler step); return ``(loss float, acc)``."""
        loss, acc = self.forward(batch_data)

        self.opt_t2m_transformer.zero_grad()
        loss.backward()
        self.opt_t2m_transformer.step()
        self.scheduler.step()

        return loss.item(), acc

    def save(self, file_name, ep, total_it):
        """Save transformer (without ``clip_model.*`` weights), optimizer and scheduler state."""
        state = {
            't2m_transformer': _strip_clip_weights(self.t2m_transformer.state_dict()),
            'opt_t2m_transformer': self.opt_t2m_transformer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'ep': ep,
            'total_it': total_it,
        }
        torch.save(state, file_name)

    def resume(self, model_dir):
        """Load a checkpoint written by :meth:`save`.

        Optimizer/scheduler state is restored best-effort: on failure the fresh
        optimizer/scheduler are kept and the reason is printed.

        Returns:
            ``(epoch, total_iterations)`` stored in the checkpoint.

        Raises:
            RuntimeError: if the model weights do not match the checkpoint.
        """
        checkpoint = torch.load(model_dir, map_location=self.device)
        missing_keys, unexpected_keys = self.t2m_transformer.load_state_dict(
            checkpoint['t2m_transformer'], strict=False)
        _check_resume_keys(missing_keys, unexpected_keys)

        try:
            self.opt_t2m_transformer.load_state_dict(checkpoint['opt_t2m_transformer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
        except Exception as err:  # best effort, but no longer swallows KeyboardInterrupt
            print('Resume wo optimizer')
            print(f'  reason: {type(err).__name__}: {err}')
        return checkpoint['ep'], checkpoint['total_it']

    def train(self, train_loader, val_loader, eval_val_loader, eval_wrapper, plot_eval):
        """Full training loop: warm-up, periodic logging/checkpointing, per-epoch validation and evaluation."""
        self.t2m_transformer.to(self.device)
        self.vq_model.to(self.device)

        self.opt_t2m_transformer = optim.AdamW(self.t2m_transformer.parameters(), betas=(0.9, 0.99),
                                               lr=self.opt.lr, weight_decay=1e-5)
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.opt_t2m_transformer,
                                                        milestones=self.opt.milestones,
                                                        gamma=self.opt.gamma)

        epoch = 0
        it = 0

        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, 'latest.tar')
            epoch, it = self.resume(model_dir)
            print("Load model epoch:%d iterations:%d" % (epoch, it))

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_loader)
        print(f'Total Epochs: {self.opt.max_epoch}, Total Iters: {total_iters}')
        print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_loader), len(val_loader)))
        logs = defaultdict(def_value, OrderedDict())

        best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_mask_transformer(
            self.opt.save_root, eval_val_loader, self.t2m_transformer, self.vq_model, self.logger, epoch,
            best_fid=100, best_div=100,
            best_top1=0, best_top2=0, best_top3=0,
            best_matching=100, eval_wrapper=eval_wrapper,
            plot_func=plot_eval, save_ckpt=False, save_anim=False
        )
        best_acc = 0.

        while epoch < self.opt.max_epoch:
            self.t2m_transformer.train()
            self.vq_model.eval()

            for i, batch in enumerate(train_loader):
                it += 1
                # Linear warm-up; `<=` so the last warm-up step reaches exactly opt.lr
                # (the scheduler keeps whatever lr is set between milestones).
                if it <= self.opt.warm_up_iter:
                    self.update_lr_warm_up(it, self.opt.warm_up_iter, self.opt.lr)

                loss, acc = self.update(batch_data=batch)
                logs['loss'] += loss
                logs['acc'] += acc
                logs['lr'] += self.opt_t2m_transformer.param_groups[0]['lr']

                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()
                    for tag, value in logs.items():
                        self.logger.add_scalar('Train/%s' % tag, value / self.opt.log_every, it)
                        mean_loss[tag] = value / self.opt.log_every
                    logs = defaultdict(def_value, OrderedDict())
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            # Increment before saving so the checkpoint records the next epoch to
            # run; saving first made a resume repeat the epoch just completed.
            epoch += 1
            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            print('Validation time:')
            self.vq_model.eval()
            self.t2m_transformer.eval()
            mean_val_loss, mean_val_acc = _validation_means(self.forward, val_loader)

            print(f"Validation loss:{mean_val_loss:.3f}, accuracy:{mean_val_acc:.3f}")

            self.logger.add_scalar('Val/loss', mean_val_loss, epoch)
            self.logger.add_scalar('Val/acc', mean_val_acc, epoch)

            if mean_val_acc > best_acc:
                print(f"Improved accuracy from {best_acc:.02f} to {mean_val_acc}!!!")
                self.save(pjoin(self.opt.model_dir, 'net_best_acc.tar'), epoch, it)
                best_acc = mean_val_acc

            best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_mask_transformer(
                self.opt.save_root, eval_val_loader, self.t2m_transformer, self.vq_model, self.logger, epoch,
                best_fid=best_fid, best_div=best_div,
                best_top1=best_top1, best_top2=best_top2, best_top3=best_top3,
                best_matching=best_matching, eval_wrapper=eval_wrapper,
                plot_func=plot_eval, save_ckpt=True, save_anim=(epoch % self.opt.eval_every_e == 0)
            )


class ResidualTransformerTrainer:
    """Train ``res_transformer`` on the full token-index tensor produced by ``vq_model``.

    ``vq_model`` is kept in eval mode and its parameters are not given to the
    optimizer; only the residual transformer is trained.
    """

    def __init__(self, args, res_transformer, vq_model):
        self.opt = args
        self.res_transformer = res_transformer
        self.vq_model = vq_model
        self.device = args.device
        self.vq_model.eval()

        if args.is_train:
            self.logger = SummaryWriter(args.log_dir)

    def update_lr_warm_up(self, nb_iter, warm_up_iter, lr):
        """Set every param group's lr to ``lr * (nb_iter + 1) / (warm_up_iter + 1)`` and return it."""
        current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
        for param_group in self.opt_res_transformer.param_groups:
            param_group["lr"] = current_lr

        return current_lr

    def forward(self, batch_data):
        """Compute cross-entropy loss and accuracy of the residual transformer for one batch.

        Args:
            batch_data: ``(conds, motion, m_lens)``. ``conds`` is moved to the
                device only when it is a tensor (otherwise passed through as-is).

        Returns:
            ``(ce_loss, acc)`` as produced by ``self.res_transformer``.
        """
        conds, motion, m_lens = batch_data
        motion = motion.detach().float().to(self.device)
        m_lens = m_lens.detach().long().to(self.device)

        # Presumably code_idx is (b, n, q) and all_codes is (q, b, n, d) -- TODO confirm.
        code_idx, _all_codes = self.vq_model.encode(motion)
        # NOTE(review): 4 is assumed to be the VQ model's temporal downsampling factor.
        m_lens = m_lens // 4

        conds = conds.to(self.device).float() if torch.is_tensor(conds) else conds

        ce_loss, _pred_ids, acc = self.res_transformer(code_idx, conds, m_lens)
        return ce_loss, acc

    def update(self, batch_data):
        """Run one optimization step (including one scheduler step); return ``(loss float, acc)``."""
        loss, acc = self.forward(batch_data)

        self.opt_res_transformer.zero_grad()
        loss.backward()
        self.opt_res_transformer.step()
        self.scheduler.step()

        return loss.item(), acc

    def save(self, file_name, ep, total_it):
        """Save transformer (without ``clip_model.*`` weights), optimizer and scheduler state."""
        state = {
            'res_transformer': _strip_clip_weights(self.res_transformer.state_dict()),
            'opt_res_transformer': self.opt_res_transformer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'ep': ep,
            'total_it': total_it,
        }
        torch.save(state, file_name)

    def resume(self, model_dir):
        """Load a checkpoint written by :meth:`save`.

        Optimizer/scheduler state is restored best-effort: on failure the fresh
        optimizer/scheduler are kept and the reason is printed.

        Returns:
            ``(epoch, total_iterations)`` stored in the checkpoint.

        Raises:
            RuntimeError: if the model weights do not match the checkpoint.
        """
        checkpoint = torch.load(model_dir, map_location=self.device)
        missing_keys, unexpected_keys = self.res_transformer.load_state_dict(
            checkpoint['res_transformer'], strict=False)
        _check_resume_keys(missing_keys, unexpected_keys)

        try:
            self.opt_res_transformer.load_state_dict(checkpoint['opt_res_transformer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
        except Exception as err:  # best effort, but no longer swallows KeyboardInterrupt
            print('Resume wo optimizer')
            print(f'  reason: {type(err).__name__}: {err}')
        return checkpoint['ep'], checkpoint['total_it']

    def train(self, train_loader, val_loader, eval_val_loader, eval_wrapper, plot_eval):
        """Full training loop: warm-up, periodic logging/checkpointing, per-epoch validation and evaluation."""
        self.res_transformer.to(self.device)
        self.vq_model.to(self.device)

        self.opt_res_transformer = optim.AdamW(self.res_transformer.parameters(), betas=(0.9, 0.99),
                                               lr=self.opt.lr, weight_decay=1e-5)
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.opt_res_transformer,
                                                        milestones=self.opt.milestones,
                                                        gamma=self.opt.gamma)

        epoch = 0
        it = 0

        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, 'latest.tar')
            epoch, it = self.resume(model_dir)
            print("Load model epoch:%d iterations:%d" % (epoch, it))

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_loader)
        print(f'Total Epochs: {self.opt.max_epoch}, Total Iters: {total_iters}')
        print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_loader), len(val_loader)))
        logs = defaultdict(def_value, OrderedDict())

        best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_res_transformer(
            self.opt.save_root, eval_val_loader, self.res_transformer, self.vq_model, self.logger, epoch,
            best_fid=100, best_div=100,
            best_top1=0, best_top2=0, best_top3=0,
            best_matching=100, eval_wrapper=eval_wrapper,
            plot_func=plot_eval, save_ckpt=False, save_anim=False
        )
        best_loss = 100
        best_acc = 0

        while epoch < self.opt.max_epoch:
            self.res_transformer.train()
            self.vq_model.eval()

            for i, batch in enumerate(train_loader):
                it += 1
                # Linear warm-up; `<=` so the last warm-up step reaches exactly opt.lr
                # (the scheduler keeps whatever lr is set between milestones).
                if it <= self.opt.warm_up_iter:
                    self.update_lr_warm_up(it, self.opt.warm_up_iter, self.opt.lr)

                loss, acc = self.update(batch_data=batch)
                logs['loss'] += loss
                logs["acc"] += acc
                logs['lr'] += self.opt_res_transformer.param_groups[0]['lr']

                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()
                    for tag, value in logs.items():
                        self.logger.add_scalar('Train/%s' % tag, value / self.opt.log_every, it)
                        mean_loss[tag] = value / self.opt.log_every
                    logs = defaultdict(def_value, OrderedDict())
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            # Increment before saving so the checkpoint records the next epoch to run.
            epoch += 1
            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            print('Validation time:')
            self.vq_model.eval()
            self.res_transformer.eval()
            mean_val_loss, mean_val_acc = _validation_means(self.forward, val_loader)

            print(f"Validation loss:{mean_val_loss:.3f}, Accuracy:{mean_val_acc:.3f}")

            self.logger.add_scalar('Val/loss', mean_val_loss, epoch)
            self.logger.add_scalar('Val/acc', mean_val_acc, epoch)

            if mean_val_loss < best_loss:
                print(f"Improved loss from {best_loss:.02f} to {mean_val_loss}!!!")
                self.save(pjoin(self.opt.model_dir, 'net_best_loss.tar'), epoch, it)
                best_loss = mean_val_loss

            # Best accuracy is tracked for reporting only; no checkpoint is written for it.
            if mean_val_acc > best_acc:
                print(f"Improved acc from {best_acc:.02f} to {mean_val_acc}!!!")
                best_acc = mean_val_acc

            best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_res_transformer(
                self.opt.save_root, eval_val_loader, self.res_transformer, self.vq_model, self.logger, epoch,
                best_fid=best_fid, best_div=best_div,
                best_top1=best_top1, best_top2=best_top2, best_top3=best_top3,
                best_matching=best_matching, eval_wrapper=eval_wrapper,
                plot_func=plot_eval, save_ckpt=True, save_anim=(epoch % self.opt.eval_every_e == 0)
            )