from typing import List, Tuple

from datasets import load_dataset
from torch.utils.data import Dataset

# Not ideal to import this type here, but it is needed to annotate the
# transform function below.
from torchtune.modules import Tokenizer

# Matches the default ignore_index of torch.nn.CrossEntropyLoss, so masked
# label positions contribute nothing to the loss.
CROSS_ENTROPY_IGNORE_IDX = -100

_PROMPT_TEMPLATE = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n"
    ),
}
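
# For illustration, formatting "prompt_no_input" with the placeholder
# instruction "Name a color." (not from the dataset) renders as:
#
#   Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
#   ### Instruction:
#   Name a color.
#
#   ### Response: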


class AlpacaDataset(Dataset):
    """
    See torchtune.datasets.AlpacaDataset for the original implementation.
    This version additionally supports custom dataset paths.

    Args:
        dataset_path (str): Path or Hugging Face Hub name passed to
            ``datasets.load_dataset``.
        tokenizer (Tokenizer): Tokenizer used to encode the prompt and response.
        train_on_input (bool): If True, prompt tokens are included in the loss;
            if False, they are masked with ``CROSS_ENTROPY_IGNORE_IDX``.
    """

    def __init__(
        self,
        dataset_path: str,
        tokenizer: Tokenizer,
        train_on_input: bool = True,
        **kwargs,
    ) -> None:
        self._data = load_dataset(dataset_path, split="train")
        self._tokenizer = tokenizer
        self.train_on_input = train_on_input

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
        sample = self._data[index]
        return self._transform(
            instruction=sample["instruction"],
            input=sample["input"],
            output=sample["output"],
        )

    def _transform(
        self, instruction: str, input: str, output: str
    ) -> Tuple[List[int], List[int]]:
        """
        Tokenize a sample and build the label sequence, optionally masking
        out the prompt portion.

        Args:
            instruction (str): Instruction text.
            input (str): Input text. Can be an empty string, which determines
                the prompt template used.
            output (str): Response text.

        Returns:
            Tuple of encoded input token ids and label ids.
        """
        prompt = self._generate_prompt(instruction, input)
        prompt_with_response = prompt + output

        # Always add bos: LlamaTokenizer sets this to True by default, and
        # neither alpaca-lora nor the original authors change it.
        encoded_prompt = self._tokenizer.encode(
            text=prompt, add_bos=True, add_eos=False
        )
        encoded_prompt_with_response = self._tokenizer.encode(
            text=prompt_with_response, add_bos=True, add_eos=True
        )
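
        # Illustrative sketch (token ids made up): if encoded_prompt is
        # [1, 306, 315] and encoded_prompt_with_response is [1, 306, 315, 891, 2],
        # then with train_on_input=False the labels built below become
        # [-100, -100, -100, 891, 2], so the loss covers only the response
        # (including the EOS token).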
        labels = encoded_prompt_with_response.copy()

        if not self.train_on_input:
            labels[: len(encoded_prompt)] = [CROSS_ENTROPY_IGNORE_IDX] * len(
                encoded_prompt
            )

        assert len(encoded_prompt_with_response) == len(labels)
        return encoded_prompt_with_response, labels

    def _generate_prompt(self, instruction: str, input: str) -> str:
        """
        Generate prompt from instruction and input.

        Args:
            instruction (str): Instruction text.
            input (str): Input text.

        Returns:
            Prompt text.
        """
        if input:
            prompt = _PROMPT_TEMPLATE["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            prompt = _PROMPT_TEMPLATE["prompt_no_input"].format(
                instruction=instruction
            )
        return prompt
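
# --- Usage sketch (not part of the original file) ---
# Assumes the "tatsu-lab/alpaca" Hub dataset and a pre-built torchtune
# Tokenizer; both names are placeholders, and tokenizer construction depends
# on the torchtune version in use.
#
#     tokenizer = ...  # a torchtune.modules.Tokenizer instance
#     dataset = AlpacaDataset(
#         dataset_path="tatsu-lab/alpaca",
#         tokenizer=tokenizer,
#         train_on_input=False,  # mask prompt tokens out of the loss
#     )
#     tokens, labels = dataset[0]  # parallel lists of token ids and label ids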