# Spaces: Runtime error
from header import * | |
from .samplers import DistributedBatchSampler | |
from .sft_dataset import * | |
# NOTE: a former module-level ``get_tokenizer`` helper used to sit here inside a
# no-op string literal. It built a LlamaTokenizer from a model path, forced the
# bos/eos token ids to 1/2, and set pad_token = eos_token. It was never executed,
# and ``load_sft_dataset`` constructs ``SupervisedDataset`` without a tokenizer,
# so the dead copy was removed. Recover it from version control if it is needed again.
def load_sft_dataset(args):
    """Build the supervised fine-tuning dataset and its distributed data loader.

    Args:
        args: Configuration mapping. Keys read here:
            ``'data_path'`` and ``'image_root_path'`` are passed to
            ``SupervisedDataset``. ``'dschf'`` is an object whose
            ``.config['train_micro_batch_size_per_gpu']`` gives the
            per-rank micro-batch size (presumably a DeepSpeed config
            wrapper; confirm against the caller).

    Returns:
        A ``(data, iter_, sampler)`` tuple with three parts:
            - the ``SupervisedDataset`` instance;
            - the ``DataLoader`` that iterates it through a
              ``DistributedBatchSampler`` and batches with ``data.collate``;
            - the underlying ``RandomSampler``.
    """
    data = SupervisedDataset(args['data_path'], args['image_root_path'])
    sampler = torch.utils.data.RandomSampler(data)

    # Query the process topology from torch.distributed. Fall back to a single
    # process when no default process group exists; get_world_size()/get_rank()
    # would otherwise raise.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
    else:
        world_size, rank = 1, 0

    # Global batch size is one micro-batch per rank. Use the same world_size
    # that is handed to the sampler, so each rank's share of the batch equals
    # the configured micro-batch size.
    micro_batch_size = args['dschf'].config['train_micro_batch_size_per_gpu']
    batch_size = world_size * micro_batch_size

    batch_sampler = DistributedBatchSampler(
        sampler,
        batch_size,
        True,  # presumably drop_last; confirm against .samplers.DistributedBatchSampler
        rank,
        world_size,
    )
    iter_ = DataLoader(
        data,
        batch_sampler=batch_sampler,
        num_workers=1,
        collate_fn=data.collate,
        pin_memory=True,
    )
    return data, iter_, sampler