from header import *  # header is expected to provide torch, DataLoader, and friends
from .samplers import DistributedBatchSampler
from .sft_dataset import *

'''
def get_tokenizer(model):
    tokenizer = LlamaTokenizer.from_pretrained(model)
    tokenizer.bos_token_id, tokenizer.eos_token_id = 1, 2
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
'''


def load_sft_dataset(args):
    '''
    tokenizer = get_tokenizer(args['model_path'])
    dataset_name = args['models'][args['model']]['stage1_train_dataset']  # SupervisedDataset, str
    data_path = args['data_path']
    data = globals()[dataset_name](data_path, tokenizer, args['max_length'])  # SupervisedDataset
    '''
    # Build the supervised fine-tuning dataset from the annotation file and image root.
    data = SupervisedDataset(args['data_path'], args['image_root_path'])

    # Shuffle globally, then let the distributed batch sampler hand each rank
    # its own slice of every global batch.
    sampler = torch.utils.data.RandomSampler(data)
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    # Global batch size = number of ranks * per-GPU micro batch size from the DeepSpeed config.
    batch_size = world_size * args['dschf'].config['train_micro_batch_size_per_gpu']
    batch_sampler = DistributedBatchSampler(
        sampler,
        batch_size,
        True,  # drop_last
        rank,
        world_size,
    )
    iter_ = DataLoader(
        data,
        batch_sampler=batch_sampler,
        num_workers=1,
        collate_fn=data.collate,
        pin_memory=True,
    )
    return data, iter_, sampler
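
# Minimal usage sketch (not part of the original module): load_sft_dataset assumes that
# torch.distributed is already initialized and that `args` mirrors the training config.
# The paths below are placeholders, and the SimpleNamespace stands in for the DeepSpeed
# config wrapper (`dschf`) that the real training script presumably passes in.
if __name__ == '__main__':
    import torch.distributed as dist
    from types import SimpleNamespace

    # Single-process "distributed" setup, just enough to smoke-test the data pipeline.
    dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:29500',
                            world_size=1, rank=0)
    args = {
        'data_path': 'data/sft_data.json',        # assumed annotation file
        'image_root_path': 'data/images/',        # assumed image directory
        'dschf': SimpleNamespace(config={'train_micro_batch_size_per_gpu': 2}),
    }
    data, loader, sampler = load_sft_dataset(args)
    print(f'loaded {len(data)} training examples')
    dist.destroy_process_group()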