from typing import List, Tuple

from datasets import load_dataset
from torch.utils.data import Dataset

# Not ideal to import this type here, but it's needed to annotate the
# tokenizer used by the transform function
from torchtune.modules import Tokenizer

# Matches the default ``ignore_index`` of PyTorch's ``nn.CrossEntropyLoss``,
# so masked label positions contribute nothing to the loss
CROSS_ENTROPY_IGNORE_IDX = -100

_PROMPT_TEMPLATE = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n"
    ),
}
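
# Illustrative rendering (values made up for this comment): with
# instruction="Name a color" and an empty input, "prompt_no_input" expands to:
#
#   Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
#   ### Instruction:
#   Name a color
#
#   ### Response: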


class AlpacaDataset(Dataset):
    """
    See torchtune.datasets.AlpacaDataset for the original implementation.
    This version supports custom dataset paths.

    Args:
        dataset_path (str): Path or name of the dataset passed to ``load_dataset``.
        tokenizer (Tokenizer): Tokenizer used to encode prompts and responses.
        train_on_input (bool): If True, compute the loss on the prompt tokens
            as well as the response; if False, mask the prompt out of the labels.
        **kwargs: Unused; accepted for config compatibility.
    """

    def __init__(
        self,
        dataset_path: str,
        tokenizer: Tokenizer,
        train_on_input: bool = True,
        **kwargs,
    ) -> None:
        self._data = load_dataset(dataset_path, split="train")
        self._tokenizer = tokenizer
        self.train_on_input = train_on_input

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
        sample = self._data[index]
        return self._transform(
            instruction=sample["instruction"],
            input=sample["input"],
            output=sample["output"],
        )

    def _transform(
        self, instruction: str, input: str, output: str
    ) -> Tuple[List[int], List[int]]:
        """
        Tokenize the prompt and the prompt-plus-response to create the model
        inputs and labels.

        Args:
            instruction (str): Instruction text.
            input (str): Input text. Can be an empty string. Determines which
                prompt template is used.
            output (str): Response text.

        Returns:
            Tuple of encoded inputs and labels.
        """
        prompt = self._generate_prompt(instruction, input)
        prompt_with_response = prompt + output

        # Always add bos; LlamaTokenizer sets this to True by default and
        # neither alpaca-lora nor the original authors change this
        encoded_prompt = self._tokenizer.encode(
            text=prompt, add_bos=True, add_eos=False
        )
        encoded_prompt_with_response = self._tokenizer.encode(
            text=prompt_with_response, add_bos=True, add_eos=True
        )
        labels = encoded_prompt_with_response.copy()

        # Mask the prompt tokens so the loss is computed only on the response
        if not self.train_on_input:
            labels[: len(encoded_prompt)] = [CROSS_ENTROPY_IGNORE_IDX] * len(
                encoded_prompt
            )
        assert len(encoded_prompt_with_response) == len(labels)

        return encoded_prompt_with_response, labels
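
    # Worked example (illustrative token ids, not real vocabulary entries):
    # if the prompt encodes to [bos, 11, 12] and prompt + response encodes to
    # [bos, 11, 12, 21, 22, eos], then with train_on_input=False the labels
    # are [-100, -100, -100, 21, 22, eos], so only the response is trained on.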

    def _generate_prompt(self, instruction: str, input: str) -> str:
        """
        Generate prompt from instruction and input.

        Args:
            instruction (str): Instruction text.
            input (str): Input text.

        Returns:
            Prompt text.
        """
        if input:
            prompt = _PROMPT_TEMPLATE["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            prompt = _PROMPT_TEMPLATE["prompt_no_input"].format(
                instruction=instruction
            )
        return prompt
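
# --- Usage sketch (illustrative, not part of the original file) ---
# Shows how the dataset might be constructed and one sample inspected. The
# tokenizer construction below is an assumption: it presumes ``Tokenizer``
# wraps a SentencePiece model file; swap in however your tokenizer is built.
# "tatsu-lab/alpaca" is the Hugging Face id of the original Alpaca data; any
# dataset with "instruction"/"input"/"output" columns works.
if __name__ == "__main__":
    tokenizer = Tokenizer.from_file("/path/to/tokenizer.model")  # assumed API
    dataset = AlpacaDataset(
        dataset_path="tatsu-lab/alpaca",
        tokenizer=tokenizer,
        train_on_input=False,
    )
    tokens, labels = dataset[0]
    # With train_on_input=False, prompt positions in ``labels`` hold
    # CROSS_ENTROPY_IGNORE_IDX; only response positions carry real targets
    print(len(tokens), labels[:10])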