from typing import List, Tuple

from datasets import load_dataset
from torch.utils.data import Dataset

# Not ideal to import this type here but it's needed for the transform function
from torchtune.modules import Tokenizer


CROSS_ENTROPY_IGNORE_IDX = -100

_PROMPT_TEMPLATE = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n"
    ),
}
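
# For reference (illustrative only), a sample with a non-empty ``input`` renders
# roughly as:
#
#   Below is an instruction that describes a task, paired with an input that
#   provides further context. Write a response that appropriately completes
#   the request.
#
#   ### Instruction:
#   <instruction>
#
#   ### Input:
#   <input>
#
#   ### Response: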


class AlpacaDataset(Dataset):
    """
    See torchtune.datasets.AlpacaDataset for the original implementation.
    This version supports custom dataset paths.
    """

    def __init__(
        self,
        dataset_path: str,
        tokenizer: Tokenizer,
        train_on_input: bool = True,
        **kwargs
    ) -> None:
        self._data = load_dataset(dataset_path, split="train", **kwargs)
        self._tokenizer = tokenizer
        self.train_on_input = train_on_input

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
        sample = self._data[index]

        return self._transform(
            instruction=sample["instruction"],
            input=sample["input"],
            output=sample["output"],
        )

    def _transform(
        self, instruction: str, input: str, output: str
    ) -> Tuple[List[int], List[int]]:
        """
        Split a sample on ``response`` tag to create input and labels.

        Args:
            instruction (str): Instruction text.
            input (str): Input text. Can be an empty string. Determines the prompt generation template
                used.
            output (str): Response text.

        Returns:
            Tuple of encoded inputs and labels.
        """
        prompt = self._generate_prompt(instruction, input)
        prompt_with_response = prompt + output

        # Always add bos; LlamaTokenizer sets this to True by default and neither
        # alpaca-lora nor the original authors change this.
        encoded_prompt = self._tokenizer.encode(
            text=prompt, add_bos=True, add_eos=False
        )
        encoded_prompt_with_response = self._tokenizer.encode(
            text=prompt_with_response, add_bos=True, add_eos=True
        )
        labels = encoded_prompt_with_response.copy()
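
        # When train_on_input is False, mask the prompt tokens below so the loss
        # is computed only on the response tokens.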

        if not self.train_on_input:
            labels[: len(encoded_prompt)] = [CROSS_ENTROPY_IGNORE_IDX] * len(
                encoded_prompt
            )

        assert len(encoded_prompt_with_response) == len(labels)

        return encoded_prompt_with_response, labels

    def _generate_prompt(self, instruction: str, input: str) -> str:
        """
        Generate prompt from instruction and input.

        Args:
            instruction (str): Instruction text.
            input (str): Input text.

        Returns:
            Prompt text.
        """
        if input:
            prompt = _PROMPT_TEMPLATE["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            prompt = _PROMPT_TEMPLATE["prompt_no_input"].format(instruction=instruction)
        return prompt
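

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The dataset name and the tokenizer
# construction below are placeholders, not part of this module; build the
# Tokenizer with whatever API your torchtune version provides.
#
#     tokenizer = ...  # a torchtune.modules.Tokenizer for your model
#     dataset = AlpacaDataset(
#         dataset_path="tatsu-lab/alpaca",  # or a local/custom dataset path
#         tokenizer=tokenizer,
#         train_on_input=False,  # mask prompt tokens from the loss
#     )
#     tokens, labels = dataset[0]  # parallel lists of token ids and label ids
# ---------------------------------------------------------------------------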