import os
import sys

sys.path.append(os.getcwd())

import logging
import random
import shutil
import time

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data

from data_utils import torch_data
from nets.init_model import init_model
from trainer.config import load_JsonConfig
from trainer.options import parse_args


def prn_obj(obj):
    """Print every attribute of *obj* as 'name:value', one per line."""
    print('\n'.join(f'{name}:{value}' for name, value in obj.__dict__.items()))


class Trainer():
    """Training harness for the speech-to-gesture models.

    Parses CLI arguments and the JSON config, seeds all RNGs, creates a
    date-stamped experiment directory with file+stdout logging, builds the
    model and dataloaders, and runs the epoch loop with periodic loss
    logging and checkpointing.
    """

    def __init__(self) -> None:
        parser = parse_args()
        self.args = parser.parse_args()
        self.config = load_JsonConfig(self.args.config_file)

        # smplx body-model resource paths are handed to downstream code
        # through environment variables.
        os.environ['smplx_npz_path'] = self.config.smplx_npz_path
        os.environ['extra_joint_path'] = self.config.extra_joint_path
        os.environ['j14_regressor_path'] = self.config.j14_regressor_path

        # NOTE(review): a commented-out wandb block containing a hard-coded
        # API key was removed here — credentials must never live in source.

        self.device = torch.device(self.args.gpu)
        torch.cuda.set_device(self.device)

        self.setup_seed(self.args.seed)
        self.set_train_dir()
        # Keep a copy of the config next to the checkpoints for reproducibility.
        shutil.copy(self.args.config_file, self.train_dir)

        self.generator = init_model(self.config.Model.model_name, self.args, self.config)
        self.init_dataloader()

        self.start_epoch = 0
        self.global_steps = 0
        if self.args.resume:
            self.resume()

    def setup_seed(self, seed):
        """Seed torch (CPU + all GPUs), numpy and random; force deterministic cuDNN."""
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True

    def set_train_dir(self):
        """Create '<save_dir>/<date>-<exp_name>-<log_name>' and wire logging.

        Log records go both to stdout (via basicConfig) and to
        '<train_dir>/train.log' (via an extra FileHandler). Sets self.train_dir.
        """
        time_stamp = time.strftime('%Y-%m-%d', time.localtime(time.time()))
        train_dir = os.path.join(
            os.getcwd(), self.args.save_dir,
            os.path.normpath(time_stamp + '-' + self.args.exp_name + '-' + self.config.Log.name))
        os.makedirs(train_dir, exist_ok=True)

        log_file = os.path.join(train_dir, 'train.log')
        fmt = "%(asctime)s-%(lineno)d-%(message)s"
        logging.basicConfig(
            stream=sys.stdout, level=logging.INFO, format=fmt,
            datefmt='%m/%d %I:%M:%S %p'
        )
        fh = logging.FileHandler(log_file)
        fh.setFormatter(logging.Formatter(fmt))
        logging.getLogger().addHandler(fh)
        self.train_dir = train_dir

    def resume(self):
        """Restore generator weights and step counters from args.pretrained_pth."""
        print('resume from a previous ckpt')
        ckpt = torch.load(self.args.pretrained_pth)
        self.generator.load_state_dict(ckpt['generator'])
        self.start_epoch = ckpt['epoch']
        self.global_steps = ckpt['global_steps']
        self.generator.global_step = self.global_steps

    def _base_dataset_kwargs(self, split_trans_zero):
        """Keyword arguments shared by every torch_data construction."""
        cfg = self.config
        return dict(
            data_root=cfg.Data.data_root,
            speakers=self.args.speakers,
            split='train',
            limbscaling=cfg.Data.pose.augmentation,
            normalization=cfg.Data.pose.normalization,
            norm_method=cfg.Data.pose.norm_method,
            split_trans_zero=split_trans_zero,
            num_pre_frames=cfg.Data.pose.pre_pose_length,
            num_frames=cfg.Data.pose.generate_length,
            aud_feat_win_size=cfg.Data.aud.aud_feat_win_size,
            aud_feat_dim=cfg.Data.aud.aud_feat_dim,
            feat_method=cfg.Data.aud.feat_method,
            context_info=cfg.Data.aud.context_info,
        )

    def _save_norm_stats(self):
        """If pose normalization is on, persist (mean, std) of the train set
        to '<train_dir>/norm_stats.npy' and keep them on self.norm_stats."""
        if self.config.Data.pose.normalization:
            self.norm_stats = (self.train_set.data_mean, self.train_set.data_std)
            save_file = os.path.join(self.train_dir, 'norm_stats.npy')
            np.save(save_file, self.norm_stats, allow_pickle=True)

    def _make_loader(self, dataset):
        """Build a shuffling, last-batch-dropping DataLoader from config."""
        return data.DataLoader(dataset,
                               batch_size=self.config.DataLoader.batch_size,
                               shuffle=True,
                               num_workers=self.config.DataLoader.num_workers,
                               drop_last=True)

    def init_dataloader(self):
        """Construct the training dataset(s) and DataLoader(s).

        freeMo models get two loaders (trans/zero splits); smplx/s2g models
        and the default case get a single loader over the whole dataset.
        """
        model_name = self.config.Model.model_name
        if 'freeMo' in model_name:
            if self.config.Data.data_root.endswith('.csv'):
                raise NotImplementedError
            self.train_set = torch_data(**self._base_dataset_kwargs(split_trans_zero=True))
            self._save_norm_stats()
            self.train_set.get_dataset()
            self.trans_set = self.train_set.trans_dataset
            self.zero_set = self.train_set.zero_dataset
            self.trans_loader = self._make_loader(self.trans_set)
            self.zero_loader = self._make_loader(self.zero_set)
        elif 'smplx' in model_name or 's2g' in model_name:
            kwargs = self._base_dataset_kwargs(split_trans_zero=False)
            kwargs.update(
                num_generate_length=self.config.Data.pose.generate_length,
                smplx=True,
                audio_sr=22000,
                convert_to_6d=self.config.Data.pose.convert_to_6d,
                expression=self.config.Data.pose.expression,
                config=self.config,
            )
            self.train_set = torch_data(**kwargs)
            self._save_norm_stats()
            self.train_set.get_dataset()
            self.train_loader = self._make_loader(self.train_set.all_dataset)
        else:
            self.train_set = torch_data(**self._base_dataset_kwargs(split_trans_zero=False))
            self._save_norm_stats()
            self.train_set.get_dataset()
            self.train_loader = self._make_loader(self.train_set.all_dataset)

    def init_optimizer(self):
        """Optimizers are created inside the model; nothing to do here."""
        pass

    def print_func(self, loss_dict, steps):
        """Log the global step and each loss's running average (sum / steps)."""
        info_str = ['global_steps:%d' % (self.global_steps)]
        info_str += ['%s:%.4f' % (key, loss_dict[key] / steps) for key in loss_dict]
        logging.info(','.join(info_str))

    def save_model(self, epoch):
        """Write generator weights + counters to '<train_dir>/ckpt-<epoch>.pth'."""
        state_dict = {
            'generator': self.generator.state_dict(),
            'epoch': epoch,
            'global_steps': self.global_steps
        }
        save_name = os.path.join(self.train_dir, 'ckpt-%d.pth' % (epoch))
        torch.save(state_dict, save_name)

    def _accumulate_losses(self, epoch_loss_dict, loss_dict):
        """Add this step's losses into the running epoch totals, in place.

        On the first step the totals dict is seeded with this step's values.
        """
        if epoch_loss_dict:  # non-empty: accumulate
            for key in loss_dict:
                epoch_loss_dict[key] += loss_dict[key]
        else:  # first step of the epoch
            for key in loss_dict:
                epoch_loss_dict[key] = loss_dict[key]

    def train_epoch(self, epoch):
        """Run one pass over the training data, logging running-average losses."""
        epoch_loss_dict = {}  # running sum of each loss over this epoch
        epoch_steps = 0
        if 'freeMo' in self.config.Model.model_name:
            # freeMo consumes paired batches from the trans/zero splits.
            for bat in zip(self.trans_loader, self.zero_loader):
                self.global_steps += 1
                epoch_steps += 1
                _, loss_dict = self.generator(bat)
                self._accumulate_losses(epoch_loss_dict, loss_dict)
                if self.global_steps % self.config.Log.print_every == 0:
                    self.print_func(epoch_loss_dict, epoch_steps)
        else:
            # e.g. smplx_S2G: single loader; epoch index is forwarded in the batch.
            for bat in self.train_loader:
                self.global_steps += 1
                epoch_steps += 1
                bat['epoch'] = epoch
                _, loss_dict = self.generator(bat)
                self._accumulate_losses(epoch_loss_dict, loss_dict)
                if self.global_steps % self.config.Log.print_every == 0:
                    self.print_func(epoch_loss_dict, epoch_steps)

    def train(self):
        """Epoch loop: train, then checkpoint every save_every epochs
        (and unconditionally at epoch 30)."""
        logging.info('start_training')
        self.total_loss_dict = {}
        for epoch in range(self.start_epoch, self.config.Train.epochs):
            logging.info('epoch:%d' % (epoch))
            self.train_epoch(epoch)
            # The extra '== 30' preserves the original hard-coded save point.
            if (epoch + 1) % self.config.Log.save_every == 0 or (epoch + 1) == 30:
                self.save_model(epoch)