|
""" |
|
Test module for sharegpt integration w chatml |
|
""" |
|
|
|
import pytest |
|
from datasets import Dataset |
|
from tokenizers import AddedToken |
|
from transformers import AutoTokenizer |
|
|
|
from axolotl.datasets import TokenizedPromptDataset |
|
from axolotl.prompt_strategies.sharegpt import ( |
|
GlaiveShareGPTPromptTokenizingStrategy, |
|
SimpleShareGPTPromptTokenizingStrategy, |
|
register_chatml_template, |
|
) |
|
from axolotl.prompters import ShareGPTPrompterV2 |
|
|
|
# The chatml conversation template must be registered before any
# ShareGPTPrompterV2(conversation="chatml") below can resolve it.
register_chatml_template()
|
|
|
|
|
@pytest.fixture(name="sharegpt_dataset") |
|
def fixture_sharegpt_dataset(): |
|
return Dataset.from_list( |
|
[ |
|
{ |
|
"conversations": [ |
|
{ |
|
"from": "system", |
|
"value": "repeat", |
|
}, |
|
{ |
|
"from": "human", |
|
"value": "hello", |
|
}, |
|
{ |
|
"from": "gpt", |
|
"value": "hello", |
|
}, |
|
{ |
|
"from": "human", |
|
"value": "goodbye", |
|
}, |
|
{ |
|
"from": "gpt", |
|
"value": "goodbye", |
|
}, |
|
] |
|
} |
|
] |
|
) |
|
|
|
|
|
@pytest.fixture(name="glaive_dataset") |
|
def fixture_sharegpt_glaive_dataset(): |
|
return Dataset.from_list( |
|
[ |
|
{ |
|
"system": "SYSTEM: This is a system prompt", |
|
"chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>", |
|
} |
|
] |
|
) |
|
|
|
|
|
@pytest.fixture(name="tokenizer") |
|
def fixture_tokenizer(): |
|
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") |
|
tokenizer.add_special_tokens( |
|
{ |
|
"eos_token": AddedToken( |
|
"<|im_end|>", rstrip=False, lstrip=False, normalized=False |
|
) |
|
} |
|
) |
|
tokenizer.add_tokens( |
|
[ |
|
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False), |
|
] |
|
) |
|
|
|
return tokenizer |
|
|
|
|
|
class TestSharegpt:
    """
    Tests that chatml tokenization of ShareGPT / glaive conversations produces
    the expected input_ids and labels. Token ids are from the Mistral vocab
    with the fixture's added tokens: 32001 = <|im_start|>, 32000 = <|im_end|>.
    """

    @staticmethod
    def _chatml_strategy(
        tokenizer,
        train_on_inputs,
        strategy_cls=SimpleShareGPTPromptTokenizingStrategy,
    ):
        # Build a chatml tokenizing strategy; every test uses this exact
        # prompter configuration and a 2048 sequence length.
        return strategy_cls(
            ShareGPTPrompterV2(
                conversation="chatml",
                role_key_model=None,
                role_key_human=None,
            ),
            tokenizer,
            train_on_inputs,
            2048,
        )

    @staticmethod
    def _first_example(strategy, dataset):
        # Tokenize the dataset with a single worker and return example 0.
        return TokenizedPromptDataset(strategy, dataset, process_count=1)[0]

    def test_no_double_im_end(self, sharegpt_dataset, tokenizer):
        # Each turn must end with exactly one <|im_end|> (32000) token.
        strategy = self._chatml_strategy(tokenizer, False)

        input_ids = self._first_example(strategy, sharegpt_dataset)["input_ids"]

        assert input_ids == [
            1,  # bos
            32001, 1587, 13, 25997, 32000, 28705, 13,  # system: repeat
            32001, 2188, 13, 21558, 32000, 28705, 13,  # human: hello
            32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt: hello
            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human: goodbye
            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt: goodbye
        ]

    def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
        # NOTE(review): the flag passed here is False and the prompt tokens are
        # expected to be masked (-100); the test name looks inverted relative
        # to the train-on-inputs flag — confirm intent.
        strategy = self._chatml_strategy(tokenizer, False)

        labels = self._first_example(strategy, sharegpt_dataset)["labels"]

        assert labels == [
            -100,  # bos
            -100, -100, -100, -100, -100, -100, -100,  # system
            -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, 13, 21558, 32000, 28705, 13,  # gpt reply kept for loss
            -100, -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, 13, 12684, 17664, 32000, 28705, 13,  # gpt reply kept for loss
        ]

    def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
        # NOTE(review): the flag passed here is True and every token is kept in
        # the labels; the test name looks inverted relative to the
        # train-on-inputs flag — confirm intent.
        strategy = self._chatml_strategy(tokenizer, True)

        labels = self._first_example(strategy, sharegpt_dataset)["labels"]

        assert labels == [
            1,  # bos
            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
            32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
        ]

    def test_chatml_glaive(self, glaive_dataset, tokenizer):
        # Glaive examples store the whole chat in one string; the strategy
        # must split it into system/user/assistant turns.
        strategy = self._chatml_strategy(
            tokenizer, True, strategy_cls=GlaiveShareGPTPromptTokenizingStrategy
        )

        labels = self._first_example(strategy, glaive_dataset)["labels"]

        assert labels == [
            1,  # bos
            32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13,  # system
            32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13,  # user
            32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13
        ]
|
|
|
|