from header import *
from dataset import load_dataset
from model import *
from config import *


def parser_args():
    parser = argparse.ArgumentParser(description='train parameters')
    parser.add_argument('--model', type=str, default='nextgpt')
    parser.add_argument('--mode', type=str, default='train', help='train or test or validation')
    parser.add_argument('--local_rank', default=0, type=int)
    parser.add_argument('--save_path', type=str, default='../ckpt/delta_ckpt/nextgpt/7b_tiva_v0/')
    parser.add_argument('--log_path', type=str, default='../ckpt/delta_ckpt/nextgpt/7b_tiva_v0/log/')
    parser.add_argument('--assets_path', type=str, default='./assets/')
    # model configurations
    parser.add_argument('--max_length', type=int, default=512)  # the maximum input sequence length for the LLM
    parser.add_argument('--stage', type=int, default=1)  # the training stage
    # note: `type=list` would split a command-line string into characters; use nargs instead
    parser.add_argument('--modality', type=str, nargs='+', default=['image', 'video', 'audio', 'text'])
    return parser.parse_args()


def initialize_distributed(args):
    args['master_ip'] = os.getenv('MASTER_ADDR', 'localhost')
    args['master_port'] = os.getenv('MASTER_PORT', '6000')
    args['world_size'] = int(os.getenv('WORLD_SIZE', '1'))
    args['local_rank'] = int(os.getenv('RANK', '0')) % torch.cuda.device_count()
    device = args['local_rank'] % torch.cuda.device_count()
    torch.cuda.set_device(device)
    deepspeed.init_distributed(dist_backend='nccl')


def set_random_seed(seed):
    if seed is not None and seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.random.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def config_env(args):
    args['root_dir'] = '../'
    # args['mode'] = 'train'
    config = load_config(args)
    args.update(config)
    initialize_distributed(args)
    set_random_seed(args['seed'])


def build_directory(path):
    # recursively construct the directory if it does not already exist
    os.makedirs(path, exist_ok=True)


def main(**args):
    config_env(args)
    print(args)
    args['ds_config_path'] = f'dsconfig/stage_{args["stage"]}.json'
    dschf = HfDeepSpeedConfig(args['ds_config_path'])
    args['dschf'] = dschf

    build_directory(args['save_path'])
    build_directory(args['log_path'])

    if args['log_path']:
        # use a filesystem-safe timestamp; time.asctime() contains spaces and colons,
        # which are invalid in filenames on some platforms
        timestamp = time.strftime('%Y%m%d_%H%M%S')
        logging.basicConfig(
            format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
            level=logging.DEBUG,
            filename=f'{args["log_path"]}/train_{timestamp}.log',
            filemode='w'
        )

    train_data, train_iter, sampler = load_dataset(args, args['dataset_name_list'])

    # estimate the effective epoch size by scaling every concatenated dataset
    # up to the size of the largest one
    train_num = max(
        [_cur_dataset.__len__() for _cur_dataset in train_data.datasets.datasets]
    ) * len(train_data.datasets.datasets)
    # per-rank micro-batch steps, used to size the progress bar
    length = args['epochs'] * train_num // args['world_size'] // dschf.config['train_micro_batch_size_per_gpu']
    # global optimizer steps, used by the scheduler
    total_steps = args['epochs'] * train_num // dschf.config['train_batch_size']
    args['total_steps'] = total_steps
    agent = load_model(args)
    torch.distributed.barrier()

    # begin to train
    pbar = tqdm(total=length)  # maximum total number of micro-batch steps on this rank
    current_step = 0
    for epoch_i in tqdm(range(args['epochs'])):
        for batch in train_iter:
            agent.train_model(
                batch,
                current_step=current_step,
                pbar=pbar
            )
            current_step += 1
            # if current_step % 2000 == 0:
            #     torch.distributed.barrier()
            #     agent.save_model(args['save_path'], current_step)
    # save at the end of training
    torch.distributed.barrier()
    agent.save_model(args['save_path'], current_step)


if __name__ == "__main__":
    args = parser_args()
    args = vars(args)
    main(**args)
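# Usage sketch (assumptions: the `deepspeed` launcher is installed, the repo ships
# dsconfig/stage_1.json, and the GPU count below is illustrative rather than the
# project's prescribed invocation). The launcher sets MASTER_ADDR/MASTER_PORT,
# WORLD_SIZE, and RANK, which initialize_distributed() reads above:
#
#   deepspeed --num_gpus 8 train.py \
#       --stage 1 \
#       --save_path ../ckpt/delta_ckpt/nextgpt/7b_tiva_v0/ \
#       --log_path ../ckpt/delta_ckpt/nextgpt/7b_tiva_v0/log/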