# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Tuple

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

from datasets import load_dataset

# Not ideal to import this type here but it's needed for the transform function
from torchtune.modules import Tokenizer


CROSS_ENTROPY_IGNORE_IDX = -100


DEFAULT = 0
INSTRUCTION = 1
INPUT = 2
RESPONSE = 3


class ColoringAlpacaDataset(Dataset):
    """
    Alpaca-style dataset whose samples also carry a per-token "color".

    See torchtune.datasets.alpaca.AlpacaDataset for the original implementation.
    Unlike the original, the constructor takes the dataset path directly and every
    sample is a triple of equally long lists: tokens, labels and token colors
    (one of DEFAULT, INSTRUCTION, INPUT, RESPONSE per token).

    Args:
        tokenizer (Tokenizer): Object whose ``encode(text=..., add_bos=..., add_eos=...)``
            turns text into a list of token ids.
        dataset_path (str): Passed to ``load_dataset``; only the "train" split is loaded.
        train_on_input (bool): When False, labels for the prompt tokens are replaced
            by ``CROSS_ENTROPY_IGNORE_IDX``; response tokens always keep their labels.
        **kwargs: Accepted but not used by this class.
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        dataset_path: str = "yahma/alpaca-cleaned",
        train_on_input: bool = True,
        **kwargs
    ) -> None:
        self._data = load_dataset(dataset_path, split="train")
        self._tokenizer = tokenizer
        self.train_on_input = train_on_input
        # One color per segment kind: DEFAULT, INSTRUCTION, INPUT, RESPONSE.
        self.num_colors = 4

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int], List[int]]:
        row = self._data[index]
        return self._transform(
            instruction=row["instruction"],
            input=row["input"],
            output=row["output"],
        )

    def _transform(
        self, instruction: str, input: str, output: str
    ) -> Tuple[List[int], List[int], List[int]]:
        """
        Encode one sample segment by segment, tracking labels and colors.

        Args:
            instruction (str): Instruction text.
            input (str): Input text. Can be an empty string. Determines the prompt
                generation template used.
            output (str): Response text.

        Returns:
            Tuple of encoded inputs, labels, token colors (all the same length).
        """
        token_ids: List[int] = []
        label_ids: List[int] = []
        color_ids: List[int] = []

        # Prompt segments: only the very first one is encoded with a BOS token.
        segments = self._generate_prompt(instruction, input)
        for position, (color, text) in enumerate(segments):
            piece = self._tokenizer.encode(
                text=text, add_bos=(position == 0), add_eos=False
            )
            token_ids.extend(piece)
            color_ids.extend([color] * len(piece))
            if self.train_on_input:
                label_ids.extend(piece)
            else:
                # Mask the prompt so it does not contribute to the loss.
                label_ids.extend([CROSS_ENTROPY_IGNORE_IDX] * len(piece))

        # The response is always trained on and closes the sequence with EOS.
        answer = self._tokenizer.encode(text=output, add_bos=False, add_eos=True)
        token_ids.extend(answer)
        color_ids.extend([RESPONSE] * len(answer))
        label_ids.extend(answer)

        assert len(token_ids) == len(label_ids)
        assert len(token_ids) == len(color_ids)

        return token_ids, label_ids, color_ids

    def _generate_prompt(self, instruction: str, input: str) -> List[Tuple[(int, str)]]:
        """
        Build the colored prompt segments for an instruction and optional input.

        Args:
            instruction (str): Instruction text.
            input (str): Input text. When empty, the template without an
                "### Input:" section is used.

        Returns:
            List of (color, templated text) pairs, in prompt order.
        """
        if input:
            intro = (
                "Below is an instruction that describes a task, paired with an input that provides further context. "
            )
        else:
            intro = "Below is an instruction that describes a task. "

        segments = [
            (
                DEFAULT,
                intro
                + "Write a response that appropriately completes the request.\n\n"
                + "### Instruction:\n",
            ),
            (INSTRUCTION, instruction),
        ]
        if input:
            segments.append((DEFAULT, "\n\n### Input:\n"))
            segments.append((INPUT, input))
        segments.append((DEFAULT, "\n\n### Response:\n"))
        return segments


# TokenPair is a tuple of three lists: tokenized text inputs, labels, colors.
TokenPair = Tuple[List[int], List[int], List[int]]


def padded_collate(
    batch: List[TokenPair],
    padding_idx: int = 0,
    ignore_idx: int = -100,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Pad a batch of (tokens, labels, colors) samples into three batch-first tensors.

    Each of the three sequences is padded with ``pad_sequence`` to the length of
    the longest sample in the batch.

    Args:
        batch (List[TokenPair]): Samples, each holding the token ids, labels and
            colors of one sequence. The three lists of a sample must have the
            same length.
        padding_idx (int): Fill value used to pad the token ids.
        ignore_idx (int): Fill value used to pad both the labels and the colors.

    Returns:
        Tuple of ``input_ids``, ``labels`` and ``colors`` tensors of dtype
        ``torch.long``, all with the same shape (batch size, longest sequence).

    Raises:
        ValueError: If the tokens, labels and colors of any sample differ in length.
    """
    # Validate per sample rather than comparing padded shapes afterwards: the
    # padded shapes only reflect the longest sequence per field, so a misaligned
    # sample could otherwise be hidden by a longer neighbour. Raising (instead
    # of asserting) keeps the check active under ``python -O``.
    for row, sample in enumerate(batch):
        n_tokens, n_labels, n_colors = len(sample[0]), len(sample[1]), len(sample[2])
        if not (n_tokens == n_labels == n_colors):
            raise ValueError(
                f"Sample {row} has mismatched lengths: "
                f"{n_tokens} tokens, {n_labels} labels, {n_colors} colors"
            )

    def _pad_field(field: int, fill_value: int) -> torch.Tensor:
        # dtype is pinned so an empty list cannot become a float tensor and
        # change the dtype of the whole padded batch.
        return pad_sequence(
            [torch.tensor(sample[field], dtype=torch.long) for sample in batch],
            batch_first=True,
            padding_value=fill_value,
        )

    input_ids = _pad_field(0, padding_idx)
    labels = _pad_field(1, ignore_idx)
    colors = _pad_field(2, ignore_idx)

    # Equal per-sample lengths (checked above) guarantee equal padded shapes.
    return input_ids, labels, colors