|
import copy |
|
|
|
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX |
|
|
|
def get_bos_eos_token_ids(tokenizer):
    """Return ``(bos_token_ids, eos_token_ids)`` for *tokenizer*, both as lists.

    Qwen-family tokenizers use no BOS token (empty list) and must define an
    EOS token. ChatGLM uses a fixed two-token prefix in place of BOS. Any
    other tokenizer falls back to its own ``bos_token_id`` / ``eos_token_id``
    attributes. Scalar ids are normalized to one-element lists.
    """
    tokenizer_name = type(tokenizer).__name__
    if tokenizer_name in ('QWenTokenizer', 'QWen2Tokenizer',
                          'Qwen2TokenizerFast'):
        # Qwen has no BOS; EOS is mandatory for correct turn termination.
        assert tokenizer.eos_token_id is not None, \
            'Please set eos_token for Qwen tokenizer!'
        bos_token_id = []
        eos_token_id = tokenizer.eos_token_id
    elif tokenizer_name == 'ChatGLMTokenizer':
        # ChatGLM's hard-coded [gMASK, sop] prefix stands in for BOS.
        bos_token_id = [64790, 64792]
        eos_token_id = tokenizer.eos_token_id
    else:
        bos_token_id = tokenizer.bos_token_id
        eos_token_id = tokenizer.eos_token_id

    # Normalize scalar ids so callers can always concatenate lists.
    if isinstance(bos_token_id, int):
        bos_token_id = [bos_token_id]
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    return bos_token_id, eos_token_id
|
|
|
def encode_fn(example,
              tokenizer,
              max_length,
              input_ids_with_output=True,
              with_image_token=False):
    """Tokenize ``example['conversation']`` into ``input_ids`` and ``labels``.

    Only these three dataset layouts are supported:

    1. Incremental pretraining dataset.
       example['conversation'] = [
           {
               'input': '',
               'output': '### Human: Can you write xxx'
           }
       ]

    2. Single-turn conversation dataset.
       example['conversation'] = [
           {
               'input': 'Give three tips for staying healthy.',
               'output': '1.Eat a balanced diet xxx'
           }
       ]

    3. Multi-turn conversation dataset.
       example['conversation'] = [
           {
               'input': 'Give three tips for staying healthy.',
               'output': '1.Eat a balanced diet xxx'
           },
           {
               'input': 'Please expand on the second point.',
               'output': 'Here is an expanded explanation of the xxx'
           }
       ]

    Prompt, BOS and separator tokens are masked with IGNORE_INDEX in
    ``labels``; output (and EOS) tokens carry loss unless a turn sets
    ``output_with_loss`` to False.
    """

    def _encode_prompt_with_images(text):
        # Tokenize the text around each image placeholder, then splice
        # IMAGE_TOKEN_INDEX between consecutive chunks.
        chunks = [
            tokenizer.encode(segment, add_special_tokens=False)
            for segment in text.split(DEFAULT_IMAGE_TOKEN)
        ]
        token_ids = []
        last = len(chunks) - 1
        for pos, chunk in enumerate(chunks):
            token_ids.extend(chunk)
            if pos != last:
                token_ids.append(IMAGE_TOKEN_INDEX)
        return token_ids

    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    # Multi-turn data is only meaningful when outputs are part of input_ids.
    if len(example['conversation']) > 1:
        assert input_ids_with_output

    input_ids, labels = [], []
    next_needs_bos_token = True
    for turn in example['conversation']:
        prompt = turn['input']
        if DEFAULT_IMAGE_TOKEN in prompt and with_image_token:
            prompt_ids = _encode_prompt_with_images(prompt)
        else:
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)

        # BOS opens the sample and reopens after every EOS-terminated turn;
        # it never carries loss.
        if next_needs_bos_token:
            input_ids += bos_token_id
            labels += [IGNORE_INDEX] * len(bos_token_id)
        input_ids += prompt_ids
        labels += [IGNORE_INDEX] * len(prompt_ids)

        if input_ids_with_output:
            # Append the assistant output; supervised unless the turn
            # disables loss via 'output_with_loss'.
            output_with_loss = turn.get('output_with_loss', True)
            output_ids = tokenizer.encode(
                turn['output'], add_special_tokens=False)
            input_ids += output_ids
            if output_with_loss:
                labels += list(output_ids)
            else:
                labels += [IGNORE_INDEX] * len(output_ids)

            # EOS closes the turn (sharing the output's loss setting); when
            # omitted, the following turn continues without a new BOS.
            if turn.get('need_eos_token', True):
                next_needs_bos_token = True
                input_ids += eos_token_id
                if output_with_loss:
                    labels += list(eos_token_id)
                else:
                    labels += [IGNORE_INDEX] * len(eos_token_id)
            else:
                next_needs_bos_token = False

            # Optional separator between turns; never supervised.
            sep = turn.get('sep', '')
            if sep != '':
                sep_ids = tokenizer.encode(sep, add_special_tokens=False)
                input_ids += sep_ids
                labels += [IGNORE_INDEX] * len(sep_ids)

    # Hard truncation to max_length (may cut through an output's tokens).
    if len(input_ids) > max_length:
        input_ids = input_ids[:max_length]
        labels = labels[:max_length]
    return {'input_ids': input_ids, 'labels': labels}