|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Tuple |
|
|
|
import torch |
|
|
|
import torch.nn.functional as F |
|
from torch.nn.utils.rnn import pad_sequence |
|
from torch.utils.data import Dataset |
|
|
|
from datasets import load_dataset |
|
|
|
|
|
from torchtune.modules import Tokenizer |
|
|
|
|
|
# Label value written in place of prompt tokens when they should not be
# trained on (matches PyTorch's default ``ignore_index`` for cross-entropy).
CROSS_ENTROPY_IGNORE_IDX = -100

# Token "colors": identify which part of the Alpaca prompt a token came from.
#   DEFAULT     -> fixed template text
#   INSTRUCTION -> the sample's instruction field
#   INPUT       -> the sample's (optional) input field
#   RESPONSE    -> the sample's output / response
DEFAULT, INSTRUCTION, INPUT, RESPONSE = range(4)
|
|
|
|
|
class ColoringAlpacaDataset(Dataset):
    """
    Alpaca-style instruction dataset that also tags every token with a "color".

    See torchtune.datasets.alpaca.AlpacaDataset for the original implementation.
    Differences from the original:

    - The constructor takes in a dataset path directly.
    - Each item is a 3-tuple ``(tokens, labels, colors)`` instead of just
      ``(tokens, labels)``. ``colors[i]`` records which prompt segment produced
      ``tokens[i]``: ``DEFAULT`` (template text), ``INSTRUCTION``, ``INPUT`` or
      ``RESPONSE``.

    Args:
        tokenizer (Tokenizer): Tokenizer used to encode each prompt segment.
        dataset_path (str): Path/name passed to ``datasets.load_dataset``; the
            ``"train"`` split is loaded.
        train_on_input (bool): If False, prompt tokens are replaced by
            ``CROSS_ENTROPY_IGNORE_IDX`` in the labels, so only response tokens
            keep their real label values.
        **kwargs: Accepted but unused (presumably so extra config keys do not
            raise — confirm with the caller).
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        dataset_path: str = "yahma/alpaca-cleaned",
        train_on_input: bool = True,
        **kwargs
    ) -> None:
        self._data = load_dataset(dataset_path, split="train")
        self._tokenizer = tokenizer
        self.train_on_input = train_on_input
        # One color per segment type; derived from the constants rather than a
        # hard-coded literal so the two cannot drift apart.
        self.num_colors = len((DEFAULT, INSTRUCTION, INPUT, RESPONSE))

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self._data)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int], List[int]]:
        """Return ``(tokens, labels, colors)`` for the sample at ``index``."""
        sample = self._data[index]

        return self._transform(
            instruction=sample["instruction"],
            input=sample["input"],
            output=sample["output"],
        )

    def _transform(
        self, instruction: str, input: str, output: str
    ) -> Tuple[List[int], List[int], List[int]]:
        """
        Tokenize one sample segment by segment and build its labels and colors.

        The templated prompt is produced by ``_generate_prompt`` as a list of
        ``(color, text)`` segments. Each segment is encoded separately so every
        resulting token can be tagged with its segment's color. Only the first
        segment gets a BOS token. The response is appended last with an EOS
        token, and its tokens are always used as labels.

        Args:
            instruction (str): Instruction text.
            input (str): Input text. Can be an empty string. Determines the prompt
                generation template used.
            output (str): Response text.

        Returns:
            Tuple of encoded tokens, labels, and token colors (all the same length).
        """
        prompt = self._generate_prompt(instruction, input)

        tokens: List[int] = []
        labels: List[int] = []
        colors: List[int] = []

        # NOTE(review): encoding segments separately can tokenize the segment
        # boundaries differently from encoding the full prompt in one call (e.g.
        # a SentencePiece tokenizer may prepend a word-boundary marker to each
        # segment). Confirm this matches what the model expects.
        for segment_idx, (color, text) in enumerate(prompt):
            segment_tokens = self._tokenizer.encode(
                text=text, add_bos=(segment_idx == 0), add_eos=False
            )
            tokens.extend(segment_tokens)
            colors.extend([color] * len(segment_tokens))
            if self.train_on_input:
                labels.extend(segment_tokens)
            else:
                # Mask prompt tokens so they do not contribute to the loss.
                labels.extend([CROSS_ENTROPY_IGNORE_IDX] * len(segment_tokens))

        response_tokens = self._tokenizer.encode(
            text=output, add_bos=False, add_eos=True
        )
        tokens.extend(response_tokens)
        colors.extend([RESPONSE] * len(response_tokens))
        labels.extend(response_tokens)

        # Internal invariant: the three sequences are aligned token-for-token.
        assert len(tokens) == len(labels)
        assert len(tokens) == len(colors)

        return tokens, labels, colors

    def _generate_prompt(self, instruction: str, input: str) -> List[Tuple[int, str]]:
        """
        Generate the Alpaca prompt as a list of colored text segments.

        Args:
            instruction (str): Instruction text.
            input (str): Input text. If falsy (empty string or None) the
                "no input" template is used and no INPUT segment is emitted.

        Returns:
            List of ``(color, text)`` pairs whose texts, concatenated, form the
            full templated prompt (ending just before the response).
        """
        if input:
            return [
                (DEFAULT, (
                    "Below is an instruction that describes a task, paired with an input that provides further context. "
                    "Write a response that appropriately completes the request.\n\n"
                    "### Instruction:\n"
                )),
                (INSTRUCTION, instruction),
                (DEFAULT, "\n\n### Input:\n"),
                (INPUT, input),
                (DEFAULT, "\n\n### Response:\n"),
            ]
        else:
            return [
                (DEFAULT, (
                    "Below is an instruction that describes a task. "
                    "Write a response that appropriately completes the request.\n\n"
                    "### Instruction:\n"
                )),
                (INSTRUCTION, instruction),
                (DEFAULT, "\n\n### Response:\n"),
            ]
|
|
|
|
|
|
|
# Per-sample item: (token ids, labels, colors). Despite the name this is a
# triple; the name is kept because other modules may import it.
TokenPair = Tuple[List[int], List[int], List[int]]


def padded_collate(
    batch: List[TokenPair],
    padding_idx: int = 0,
    ignore_idx: int = -100,
    *,
    color_padding_idx: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Right-pad a batch of ``(tokens, labels, colors)`` samples into 2-D tensors.

    Args:
        batch: Samples, e.g. items from ``ColoringAlpacaDataset``. Within each
            sample the three lists are expected to have equal length.
        padding_idx: Value used to pad token ids.
        ignore_idx: Value used to pad labels (default -100, the same value as
            ``CROSS_ENTROPY_IGNORE_IDX``).
        color_padding_idx: Value used to pad colors. Defaults to 0 (the
            ``DEFAULT`` color). It is separate from ``padding_idx`` because a
            token padding id is generally not a valid color index.

    Returns:
        ``(input_ids, labels, colors)``, each an int64 tensor of shape
        ``(len(batch), longest_sequence_length)``.
    """

    def _pad(field: int, value: int) -> torch.Tensor:
        # Explicit dtype: torch.tensor([]) would otherwise infer float32, which
        # could turn the whole padded tensor into floats.
        return pad_sequence(
            [torch.tensor(sample[field], dtype=torch.long) for sample in batch],
            batch_first=True,
            padding_value=value,
        )

    input_ids = _pad(0, padding_idx)
    labels = _pad(1, ignore_idx)
    colors = _pad(2, color_padding_idx)

    assert input_ids.shape[-1] == labels.shape[-1]
    assert input_ids.shape[-1] == colors.shape[-1]

    return input_ids, labels, colors
|
|