fix packing so that concatenated sequences reset the attention
Files changed:
- src/axolotl/datasets.py  (+5, -0)
- tests/fixtures/alpaca/alpaca.json  (+12, -0)
- tests/test_packed_dataset.py  (+64, -0)
src/axolotl/datasets.py  CHANGED

```diff
@@ -127,6 +127,11 @@ class ConstantLengthDataset(IterableDataset):
                     input_ids = example["input_ids"]
                     attention_mask = example["attention_mask"]
                     labels = example["labels"]
+                    if (
+                        buffer["input_ids"]
+                        and input_ids[0] == self.tokenizer.bos_token_id
+                    ):
+                        attention_mask[0] = 0

                     if add_concat_token:
                         input_ids.append(self.concat_token_id)
```
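For context, the rule this hunk adds can be sketched in isolation: when an example is appended to a non-empty packing buffer and begins with the BOS token, its first attention-mask position is zeroed, so the packed example cannot attend back across the boundary into the previous sequence. The helper below is a minimal sketch of that rule, not axolotl's actual `ConstantLengthDataset` logic; the `pack` function and its handling of the concat token are assumptions for illustration.

```python
# Minimal sketch (assumed, not axolotl's ConstantLengthDataset) of the
# boundary rule added by the hunk above.
def pack(examples, bos_token_id, concat_token_id):
    packed = {"input_ids": [], "attention_mask": []}
    for example in examples:
        input_ids = list(example["input_ids"])
        attention_mask = list(example["attention_mask"])
        # Same condition as the diff: the buffer already holds tokens and
        # the incoming example starts with BOS -> zero the mask at the
        # boundary so attention resets between concatenated sequences.
        if packed["input_ids"] and input_ids[0] == bos_token_id:
            attention_mask[0] = 0
        packed["input_ids"].extend(input_ids + [concat_token_id])
        packed["attention_mask"].extend(attention_mask + [1])
    return packed


# With BOS id 1 and concat id 2, the second example's BOS gets mask 0:
a = {"input_ids": [1, 5, 6], "attention_mask": [1, 1, 1]}
b = {"input_ids": [1, 7, 8], "attention_mask": [1, 1, 1]}
assert pack([a, b], 1, 2)["attention_mask"] == [1, 1, 1, 1, 0, 1, 1, 1]
```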
    	
tests/fixtures/alpaca/alpaca.json  ADDED

```json
[
  {
    "instruction": "You will be given a series of words. Output these words in reverse order, with each word on its own line.",
    "input": "Words: ['Hello', 'world'].",
    "output": "['world', 'Hello']"
  },
  {
    "instruction": "In this task, you're given a short description of an event. Your job is to order the steps involved in the event from first to last. Note that there may be multiple correct answers for each event.",
    "input": "Description: A man walks into a bar and orders a drink. He pays for his drink and leaves the bar.",
    "output": "1. The man walks into the bar.\n2. He orders a drink.\n3. He pays for his drink.\n4. He leaves the bar."
  }
]
```
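The fixture uses the standard Alpaca schema: an `instruction`, an optional `input` giving context, and the target `output`. For orientation only, the widely known Alpaca instruct template renders such a record roughly as sketched below; the exact string `AlpacaPrompter("chat")` produces may differ, so treat the template here as an assumption.

```python
# Approximate rendering of the first fixture record with the common
# "### Instruction / ### Input / ### Response" Alpaca layout. This is an
# illustrative assumption, not AlpacaPrompter's guaranteed output.
record = {
    "instruction": "You will be given a series of words. Output these words "
    "in reverse order, with each word on its own line.",
    "input": "Words: ['Hello', 'world'].",
    "output": "['world', 'Hello']",
}
prompt = (
    f"### Instruction:\n{record['instruction']}\n\n"
    f"### Input:\n{record['input']}\n\n"
    f"### Response:\n{record['output']}"
)
print(prompt)
```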
    	
tests/test_packed_dataset.py  ADDED

```python
"""Module for testing dataset sequence packing"""

import unittest
from pathlib import Path

from datasets import Dataset, load_dataset
from transformers import AutoTokenizer

from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter


class TestPacking(unittest.TestCase):
    """
    Test class for packing dataset sequences
    """

    def setUp(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
        self.tokenizer.add_special_tokens(
            {
                "bos_token": "<s>",
                "eos_token": "</s>",
                "unk_token": "<unk>",
            }
        )

    def test_resets_attention(self):
        prompter = AlpacaPrompter("chat")
        strat = AlpacaPromptTokenizingStrategy(
            prompter,
            self.tokenizer,
            False,
            2048,
        )
        dataset = load_dataset(
            "json",
            data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
        )["train"]
        dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dataset)))

        constant_len_dataset = ConstantLengthDataset(
            self.tokenizer,
            [dataset],
            seq_length=2048,
        )
        packed_dataset = Dataset.from_list(list(constant_len_dataset))
        example = packed_dataset[0]
        next_bos_index = (
            example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
        )  # add one since we sliced

        # the first packed example keeps its attention mask intact
        assert example["input_ids"][0] == self.tokenizer.bos_token_id
        assert example["attention_mask"][0] == 1

        # but the subsequent example's BOS has its mask reset
        assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
        assert example["attention_mask"][next_bos_index] == 0


if __name__ == "__main__":
    unittest.main()
```
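Because the test is plain `unittest` with a `unittest.main()` guard, it can be run directly (`python tests/test_packed_dataset.py`) or through a runner such as pytest; note that `setUp` downloads the `huggyllama/llama-7b` tokenizer, so the first run needs network access. The two assertion pairs pin down the invariant the fix introduces: the first example in a packed window keeps `attention_mask == 1` at its BOS, while the next example's BOS is masked to 0, marking the seam between concatenated sequences.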