# PandaGPT/datasets/__init__.py
from header import *
from .samplers import DistributedBatchSampler
from .sft_dataset import *
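# NOTE: `header` is the project-wide import hub; it is assumed to re-export
# torch and DataLoader, which are used below without explicit imports.
# `SupervisedDataset` arrives via the star-import from .sft_dataset.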
# Legacy tokenizer helper, kept for reference:
'''
def get_tokenizer(model):
    tokenizer = LlamaTokenizer.from_pretrained(model)
    tokenizer.bos_token_id, tokenizer.eos_token_id = 1, 2
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
'''
def load_sft_dataset(args):
    '''
    Legacy text-only loading path, kept for reference:

    tokenizer = get_tokenizer(args['model_path'])
    dataset_name = args['models'][args['model']]['stage1_train_dataset']  # SupervisedDataset, str
    data_path = args['data_path']
    data = globals()[dataset_name](data_path, tokenizer, args['max_length'])  # SupervisedDataset
    '''
    data = SupervisedDataset(args['data_path'], args['image_root_path'])
    sampler = torch.utils.data.RandomSampler(data)
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    # Global batch size: per-GPU micro batch size scaled by the number of ranks;
    # args['world_size'] is assumed to match torch.distributed.get_world_size().
    batch_size = args['world_size'] * args['dschf'].config['train_micro_batch_size_per_gpu']
    batch_sampler = DistributedBatchSampler(
        sampler,
        batch_size,
        True,        # drop_last: discard the trailing incomplete batch
        rank,
        world_size,
    )
    iter_ = DataLoader(
        data,
        batch_sampler=batch_sampler,
        num_workers=1,
        collate_fn=data.collate,
        pin_memory=True,
    )
    return data, iter_, sampler
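
# Usage sketch (illustrative only, not part of this module's entry point):
# assumes torch.distributed has already been initialized by the launcher
# (e.g. deepspeed) and that `args` carries the keys read above
# ('data_path', 'image_root_path', 'world_size', 'dschf'); `model` is
# hypothetical.
#
#     data, iter_, sampler = load_sft_dataset(args)
#     for batch in iter_:
#         loss = model(batch)  # hypothetical forward pass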