# Uploaded by zhouyik using huggingface_hub (commit 032e687, verified).
import copy
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
def get_bos_eos_token_ids(tokenizer):
    """Return the BOS and EOS token ids used to wrap conversation turns.

    The tokenizer family is detected from its class name. Both ids are
    normalised to lists of ints (BOS may be an empty list), so callers can
    extend token sequences with them directly.

    Args:
        tokenizer: A HuggingFace-style tokenizer exposing ``bos_token_id``
            and ``eos_token_id`` attributes.

    Returns:
        tuple: ``(bos_token_id, eos_token_id)``. ``bos_token_id`` is a list
        of ints, possibly empty. ``eos_token_id`` is a list of ints when the
        tokenizer reports an int (or a list) for it; otherwise it is passed
        through unchanged.

    Raises:
        AssertionError: If a Qwen tokenizer has no ``eos_token_id`` set.
    """
    tokenizer_name = tokenizer.__class__.__name__
    # 'QWen2Tokenizer' is kept for backward compatibility. The transformers
    # slow class is actually named 'Qwen2Tokenizer'; without it, slow Qwen2
    # tokenizers fell through to the generic branch below.
    if tokenizer_name in (
            'QWenTokenizer', 'QWen2Tokenizer', 'Qwen2Tokenizer',
            'Qwen2TokenizerFast'
    ):
        # Qwen tokenizers are treated as having no BOS prefix at all.
        bos_token_id = []
        eos_token_id = tokenizer.eos_token_id
        assert eos_token_id is not None, \
            'Please set eos_token for Qwen tokenizer!'
    elif tokenizer_name == 'ChatGLMTokenizer':
        # Fixed two-token prefix for ChatGLM (presumably its [gMASK]/sop
        # ids -- verify against the tokenizer vocab).
        bos_token_id = [64790, 64792]
        eos_token_id = tokenizer.eos_token_id
    else:
        bos_token_id = tokenizer.bos_token_id
        eos_token_id = tokenizer.eos_token_id

    # Tokenizers without a BOS token report None. Treat that as "no BOS"
    # so callers can still do `input_ids += bos_token_id`.
    if bos_token_id is None:
        bos_token_id = []
    if isinstance(bos_token_id, int):
        bos_token_id = [bos_token_id]
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    return bos_token_id, eos_token_id
def encode_fn(example,
              tokenizer,
              max_length,
              input_ids_with_output=True,
              with_image_token=False):
    """Tokenize a conversation example into ``input_ids`` and ``labels``.

    We only support the following three scenarios:

    1. Incremental pretraining dataset.
        example['conversation'] = [
                {
                    'input': '',
                    'output': '### Human: Can you write xxx'
                }
            ]

    2. Single-turn conversation dataset.
        example['conversation'] = [
                {
                    'input': 'Give three tips for staying healthy.',
                    'output': '1.Eat a balanced diet xxx'
                }
            ]

    3. Multi-turn conversation dataset.
        example['conversation'] = [
                {
                    'input': 'Give three tips for staying healthy.',
                    'output': '1.Eat a balanced diet xxx'
                },
                {
                    'input': 'Please expand on the second point.',
                    'output': 'Here is an expanded explanation of the xxx'
                }
            ]

    Each turn may also carry the optional keys ``'output_with_loss'``
    (default True), ``'need_eos_token'`` (default True) and ``'sep'``
    (default ``''``).

    Args:
        example: Dict with a ``'conversation'`` list of turn dicts.
        tokenizer: Tokenizer with an ``encode(text, add_special_tokens=...)``
            method.
        max_length: Maximum length; longer sequences are truncated.
        input_ids_with_output: Whether to append each turn's output (plus
            EOS/sep). Must be True for multi-turn conversations.
        with_image_token: When True, ``DEFAULT_IMAGE_TOKEN`` placeholders in
            inputs are replaced by ``IMAGE_TOKEN_INDEX``.

    Returns:
        dict: ``{'input_ids': [...], 'labels': [...]}`` of equal length.
        BOS, input and separator positions are labelled ``IGNORE_INDEX``.

    Raises:
        AssertionError: For a multi-turn conversation with
            ``input_ids_with_output=False``.
    """
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    is_multi_turn_conversation = len(example['conversation']) > 1
    if is_multi_turn_conversation:
        assert input_ids_with_output

    input_ids, labels = [], []
    # A BOS is emitted before the first turn and after every turn that
    # ended with an EOS token.
    next_needs_bos_token = True
    for single_turn_conversation in example['conversation']:
        input_text = single_turn_conversation['input']
        input_encode = _encode_input_text(input_text, tokenizer,
                                          with_image_token)
        if next_needs_bos_token:
            input_ids += bos_token_id
            labels += [IGNORE_INDEX] * len(bos_token_id)
        input_ids += input_encode
        labels += [IGNORE_INDEX] * len(input_encode)

        if not input_ids_with_output:
            continue

        # Add output
        output_with_loss = single_turn_conversation.get(
            'output_with_loss', True)
        output = single_turn_conversation['output']
        output_encode = tokenizer.encode(output, add_special_tokens=False)
        input_ids += output_encode
        # `list +=` copies the (immutable int) elements, so no aliasing.
        if output_with_loss:
            labels += output_encode
        else:
            labels += [IGNORE_INDEX] * len(output_encode)

        # Add EOS_TOKEN; it follows the same loss setting as the output.
        if single_turn_conversation.get('need_eos_token', True):
            next_needs_bos_token = True
            input_ids += eos_token_id
            if output_with_loss:
                labels += eos_token_id
            else:
                labels += [IGNORE_INDEX] * len(eos_token_id)
        else:
            next_needs_bos_token = False

        # Add SEP (without loss)
        sep = single_turn_conversation.get('sep', '')
        if sep != '':
            sep_encode = tokenizer.encode(sep, add_special_tokens=False)
            input_ids += sep_encode
            labels += [IGNORE_INDEX] * len(sep_encode)

    if len(input_ids) > max_length:
        input_ids = input_ids[:max_length]
        labels = labels[:max_length]
    return {'input_ids': input_ids, 'labels': labels}


def _encode_input_text(text, tokenizer, with_image_token):
    """Tokenize one turn's input, mapping image placeholders to ids.

    When ``with_image_token`` is True and ``text`` contains
    ``DEFAULT_IMAGE_TOKEN``, the text is split on that placeholder. Each
    piece is tokenized without special tokens, and one ``IMAGE_TOKEN_INDEX``
    is inserted between consecutive pieces. Otherwise the whole text is
    tokenized as-is, so any placeholder is tokenized literally.
    """
    if not (DEFAULT_IMAGE_TOKEN in text and with_image_token):
        return tokenizer.encode(text, add_special_tokens=False)
    encoded = []
    for idx, chunk in enumerate(text.split(DEFAULT_IMAGE_TOKEN)):
        if idx > 0:
            encoded.append(IMAGE_TOKEN_INDEX)
        encoded.extend(tokenizer.encode(chunk, add_special_tokens=False))
    return encoded