laurencer's picture
Step 6000
261dbc8 verified
raw
history blame
5.51 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Tuple
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from datasets import load_dataset
# Not ideal to import this type here, but it's needed for the dataset constructor's type annotation
from torchtune.modules import Tokenizer
# Label value that masks a position out of the loss; -100 is the default
# ``ignore_index`` of torch.nn.functional.cross_entropy.
CROSS_ENTROPY_IGNORE_IDX = -100

# Token "colors": which part of the Alpaca prompt template a token came from.
# DEFAULT marks template boilerplate; the other three mark the instruction,
# the optional input, and the response text. Values are 0, 1, 2, 3.
DEFAULT, INSTRUCTION, INPUT, RESPONSE = range(4)
class ColoringAlpacaDataset(Dataset):
    """
    Alpaca-style instruction dataset that also reports a "color" per token.

    Adapted from torchtune.datasets.alpaca.AlpacaDataset. Differences from the
    original: the constructor takes the dataset path directly, and every item
    is a triple ``(tokens, labels, colors)`` rather than ``(tokens, labels)``.
    A token's color is one of DEFAULT, INSTRUCTION, INPUT or RESPONSE and says
    which segment of the templated prompt/response produced it.
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        dataset_path: str = "yahma/alpaca-cleaned",
        train_on_input: bool = True,
        **kwargs
    ) -> None:
        # Any extra keyword arguments are accepted but not used.
        self._data = load_dataset(dataset_path, split="train")
        self._tokenizer = tokenizer
        self.train_on_input = train_on_input
        # One color per constant: DEFAULT, INSTRUCTION, INPUT, RESPONSE.
        self.num_colors = 4

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int], List[int]]:
        record = self._data[index]
        return self._transform(
            instruction=record["instruction"],
            input=record["input"],
            output=record["output"],
        )

    def _transform(
        self, instruction: str, input: str, output: str
    ) -> Tuple[List[int], List[int], List[int]]:
        """
        Tokenize one example into parallel token, label and color lists.

        Each prompt segment is encoded separately so that every token can be
        tagged with the color of the segment it came from. Only the first
        segment receives a BOS token; the response is encoded last and ends
        with EOS.

        Args:
            instruction (str): Instruction text.
            input (str): Input text. May be empty; an empty input selects the
                prompt template without an "### Input:" section.
            output (str): Response text.

        Returns:
            Tuple of (token ids, labels, colors), all of the same length.
            Prompt labels are the token ids when ``train_on_input`` is set,
            otherwise CROSS_ENTROPY_IGNORE_IDX; response labels are always the
            response token ids.
        """
        segments = self._generate_prompt(instruction, input)

        tokens: List[int] = []
        labels: List[int] = []
        colors: List[int] = []

        for position, (color, text) in enumerate(segments):
            piece = self._tokenizer.encode(
                text=text, add_bos=position == 0, add_eos=False
            )
            tokens.extend(piece)
            colors.extend([color] * len(piece))
            # Prompt tokens are supervised only when training on the input.
            labels.extend(
                piece
                if self.train_on_input
                else [CROSS_ENTROPY_IGNORE_IDX] * len(piece)
            )

        # The response is always supervised and is terminated with EOS.
        piece = self._tokenizer.encode(text=output, add_bos=False, add_eos=True)
        tokens.extend(piece)
        colors.extend([RESPONSE] * len(piece))
        labels.extend(piece)

        assert len(tokens) == len(labels)
        assert len(tokens) == len(colors)
        return tokens, labels, colors

    def _generate_prompt(self, instruction: str, input: str) -> List[Tuple[int, str]]:
        """
        Build the Alpaca prompt as a list of (color, text) segments.

        Args:
            instruction (str): Instruction text.
            input (str): Input text; when falsy, the "### Input:" section is
                omitted and the shorter preamble is used.

        Returns:
            List of (color, text) pairs in prompt order, ending with the
            "### Response:" header.
        """
        if input:
            preamble = (
                "Below is an instruction that describes a task, paired with an input that provides further context. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n"
            )
        else:
            preamble = (
                "Below is an instruction that describes a task. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n"
            )

        segments: List[Tuple[int, str]] = [
            (DEFAULT, preamble),
            (INSTRUCTION, instruction),
        ]
        if input:
            segments.extend([(DEFAULT, "\n\n### Input:\n"), (INPUT, input)])
        segments.append((DEFAULT, "\n\n### Response:\n"))
        return segments
# TokenPair is, despite its name, a triple (tuple) of three equal-length lists:
# tokenized text inputs, labels, and per-token colors.
TokenPair = Tuple[List[int], List[int], List[int]]
def padded_collate(
    batch: List["TokenPair"],
    padding_idx: int = 0,
    ignore_idx: int = -100,
    color_padding_idx: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Collate (tokens, labels, colors) triples into right-padded batch tensors.

    Args:
        batch: Samples, each a triple of int lists (token ids, labels, colors).
        padding_idx: Value used to pad token ids.
        ignore_idx: Value used to pad labels; the default -100 matches
            CROSS_ENTROPY_IGNORE_IDX so padded positions are ignored by the loss.
        color_padding_idx: Value used to pad colors. Defaults to 0, the DEFAULT
            color, so padded colors stay inside the color range instead of
            inheriting the token pad id.

    Returns:
        Tuple of (input_ids, labels, colors), each a ``(batch, max_len)``
        int64 tensor.
    """

    def _pad(column: int, value: int) -> torch.Tensor:
        # Force int64 so an empty sample does not become a float tensor
        # (torch.tensor([]) defaults to float32).
        return pad_sequence(
            [torch.tensor(sample[column], dtype=torch.long) for sample in batch],
            batch_first=True,
            padding_value=value,
        )

    input_ids = _pad(0, padding_idx)
    labels = _pad(1, ignore_idx)
    colors = _pad(2, color_padding_idx)

    seq_len = input_ids.shape[-1]
    assert seq_len == labels.shape[-1]
    assert seq_len == colors.shape[-1]
    return input_ids, labels, colors