from typing import List, Tuple

from datasets import load_dataset
from torch.utils.data import Dataset

# Not ideal to import this type here, but it's needed for the transform function.
from torchtune.modules import Tokenizer

# Default ignore_index of torch.nn.CrossEntropyLoss; tokens labeled with this
# value are excluded from the loss.
CROSS_ENTROPY_IGNORE_IDX = -100

_PROMPT_TEMPLATE = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n"
    ),
}


class AlpacaDataset(Dataset):
    """
    See ``torchtune.datasets.AlpacaDataset`` for the original implementation.

    This version supports custom dataset paths instead of hardcoding the
    Alpaca dataset.
    """

    def __init__(
        self,
        dataset_path: str,
        tokenizer: Tokenizer,
        train_on_input: bool = True,
        **kwargs,  # Extra config kwargs are accepted and ignored.
    ) -> None:
        self._data = load_dataset(dataset_path, split="train")
        self._tokenizer = tokenizer
        self.train_on_input = train_on_input

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
        sample = self._data[index]
        return self._transform(
            instruction=sample["instruction"],
            input=sample["input"],
            output=sample["output"],
        )

    def _transform(
        self, instruction: str, input: str, output: str
    ) -> Tuple[List[int], List[int]]:
        """
        Encode the prompt and the prompt + response, and build labels that
        optionally mask the prompt tokens out of the loss.

        Args:
            instruction (str): Instruction text.
            input (str): Input text. Can be an empty string, which determines
                the prompt template used.
            output (str): Response text.

        Returns:
            Tuple of encoded inputs and labels.
        """
        prompt = self._generate_prompt(instruction, input)
        prompt_with_response = prompt + output

        # Always add bos; LlamaTokenizer sets this to True by default and
        # neither alpaca-lora nor the original authors change this.
        encoded_prompt = self._tokenizer.encode(
            text=prompt, add_bos=True, add_eos=False
        )
        encoded_prompt_with_response = self._tokenizer.encode(
            text=prompt_with_response, add_bos=True, add_eos=True
        )

        labels = encoded_prompt_with_response.copy()

        if not self.train_on_input:
            # Mask the prompt tokens so the loss is computed only on the response.
            labels[: len(encoded_prompt)] = [CROSS_ENTROPY_IGNORE_IDX] * len(
                encoded_prompt
            )

        assert len(encoded_prompt_with_response) == len(labels)

        return encoded_prompt_with_response, labels

    def _generate_prompt(self, instruction: str, input: str) -> str:
        """
        Generate the prompt from the instruction and (possibly empty) input.

        Args:
            instruction (str): Instruction text.
            input (str): Input text.

        Returns:
            Prompt text.
        """
        if input:
            prompt = _PROMPT_TEMPLATE["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            prompt = _PROMPT_TEMPLATE["prompt_no_input"].format(
                instruction=instruction
            )
        return prompt
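

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). A minimal example of wiring this dataset
# into a DataLoader. The tokenizer path, the "tatsu-lab/alpaca" dataset id,
# the ``Tokenizer.from_file`` constructor, and the ``pad_collate`` helper are
# all assumptions for illustration, not part of this module's API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader

    def pad_collate(
        batch: List[Tuple[List[int], List[int]]], pad_id: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Right-pad token ids with ``pad_id`` and labels with the ignore
        index so variable-length samples can be stacked into one batch."""
        max_len = max(len(tokens) for tokens, _ in batch)
        input_ids = torch.full((len(batch), max_len), pad_id, dtype=torch.long)
        labels = torch.full(
            (len(batch), max_len), CROSS_ENTROPY_IGNORE_IDX, dtype=torch.long
        )
        for i, (tokens, label) in enumerate(batch):
            input_ids[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
            labels[i, : len(label)] = torch.tensor(label, dtype=torch.long)
        return input_ids, labels

    # Hypothetical tokenizer path; ``from_file`` is assumed from early torchtune.
    tokenizer = Tokenizer.from_file("/path/to/tokenizer.model")
    dataset = AlpacaDataset(
        dataset_path="tatsu-lab/alpaca",  # assumed Hugging Face dataset id
        tokenizer=tokenizer,
        train_on_input=False,  # mask prompt tokens out of the loss
    )
    loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=pad_collate)
    input_ids, labels = next(iter(loader))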