# LightGPT data.py

import random
from os import path
from copy import deepcopy
from datasets import load_dataset
from tiktoken import Encoding
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import IterableDataset, Dataset
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm


class Fineweb(IterableDataset):
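    """An iterable dataset over the HuggingFace FineWeb corpus. On first use the
    corpus is tokenized and written to a binary memory map; thereafter, samples
    are drawn as random `tokens_per_sample`-long windows paired with next-token
    labels."""
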
DATASET_NAME = "HuggingFaceFW/fineweb"
TEST_SPLIT_PROPORTION = 0.005
NUM_SHARDS = 1024
PADDING_INDEX = -100

    def __init__(
self,
tokenizer: Encoding,
root_path: str = "./dataset",
subset: str | None = "sample-10BT",
split: str = "train",
tokens_per_sample: int = 1024,
samples_per_epoch: int = 4096,
num_processes: int = 8,
):
super().__init__()

        if subset is not None and subset not in (
            "sample-10BT",
            "sample-100BT",
            "sample-350BT",
        ):
            raise ValueError(f"Invalid subset, {subset} given.")

        if split not in ("train", "test"):
            raise ValueError(f"Split must be either train or test, {split} given.")

        if tokens_per_sample < 1:
            raise ValueError(
                f"Tokens per sample must be greater than 0, {tokens_per_sample} given."
            )

        if samples_per_epoch < 1:
            raise ValueError(
                f"Samples per epoch must be greater than 0, {samples_per_epoch} given."
            )

        dataset_name = f"fineweb-{subset}" if subset is not None else "fineweb"

        train_path = path.join(root_path, f"{dataset_name}-train-{tokenizer.name}.bin")
        test_path = path.join(root_path, f"{dataset_name}-test-{tokenizer.name}.bin")

        self.tokenizer = tokenizer

        if not path.exists(train_path) or not path.exists(test_path):
            dataset = load_dataset(
                self.DATASET_NAME,
                name=subset,
                num_proc=num_processes,
                split="train",
            ).map(
                self.tokenize,
                desc="Tokenizing",
                remove_columns=["text", "token_count"],
                num_proc=num_processes,
            )

            tokenized_splits = dataset.train_test_split(
                test_size=self.TEST_SPLIT_PROPORTION
            )
            # Write each split to its own binary file. Use distinct loop variable
            # names so we don't shadow the `split` argument, which is still
            # needed below to choose which file to load.
            for split_name, split_dataset in tokenized_splits.items():
                bin_path = train_path if split_name == "train" else test_path

                total_length = np.sum(split_dataset["length"], dtype=np.uint64)

                bin_out = np.memmap(
                    bin_path, dtype=np.uint16, mode="w+", shape=total_length
                )

                index = 0

                for i in tqdm(range(self.NUM_SHARDS), desc="Writing"):
                    batch = split_dataset.shard(
                        num_shards=self.NUM_SHARDS, index=i, contiguous=True
                    ).with_format("numpy")

                    token_batch = np.concatenate(batch["tokens"])

                    n = len(token_batch)

                    bin_out[index : index + n] = token_batch

                    index += n

                bin_out.flush()

        bin_file_path = train_path if split == "train" else test_path

        memmap = np.memmap(bin_file_path, dtype=np.uint16, mode="r")

        self.memmap = memmap
        self.max_start = len(memmap) - (tokens_per_sample + 1)
        self.tokens_per_sample = tokens_per_sample
        self.samples_per_epoch = samples_per_epoch

    @property
    def vocabulary_size(self) -> int:
        return self.tokenizer.max_token_value + 1

    @property
    def eos_index(self) -> int:
        return self.tokenizer.eot_token

    def tokenize(self, sample: dict) -> dict:
        """Tokenize the text of a sample and append the end-of-text token."""

        tokens = self.tokenizer.encode_ordinary(sample["text"])

        tokens.append(self.tokenizer.eot_token)

        return {
            "tokens": tokens,
            "length": len(tokens),
        }

    def __iter__(self):
        """Yield randomly positioned training windows paired with next-token labels."""

        for _ in range(self.samples_per_epoch):
            start = random.randint(0, self.max_start)
            end = start + self.tokens_per_sample

            x = self.memmap[start:end]
            y = self.memmap[start + 1 : end + 1]

            x = x.astype(np.int64)
            y = y.astype(np.int64)

            assert x.shape == y.shape, "Sample / label shape mismatch."

            yield x, y


class Alpaca(Dataset):
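    """A map-style dataset over the tatsu-lab/alpaca instruction-tuning corpus.
    Each row is rendered into a prompt/response pair; when `mask_input` is True,
    prompt tokens are excluded from the loss by setting their labels to
    `PADDING_INDEX`."""
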
DATASET_NAME = "tatsu-lab/alpaca"
PADDING_INDEX = -100
PROMPT_TEMPLATE = (
"Below is an instruction that describes a task. Write a response that "
"appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n"
"### Response:\n"
)
PROMPT_TEMPLATE_WITH_INPUT = (
"Below is an instruction that describes a task, paired with an input "
"that provides further context. Write a response that appropriately "
"completes the request.\n\n"
"### Input:\n{input}\n\n"
"### Instruction:\n{instruction}\n\n"
"### Response:\n"
)
RESPONSE_TEMPLATE = "{output}"

    def __init__(
self,
tokenizer: Encoding,
max_tokens_per_sample: int = 1024,
mask_input: bool = False,
):
super().__init__()

        if max_tokens_per_sample < 1:
            raise ValueError(
                f"Max tokens per sample must be greater than 0, {max_tokens_per_sample} given."
            )

        self.tokenizer = tokenizer

        self.dataset = load_dataset(self.DATASET_NAME, split="train")

        self.max_tokens_per_sample = max_tokens_per_sample
        self.mask_input = mask_input

    @property
    def vocabulary_size(self) -> int:
        return self.tokenizer.max_token_value + 1

    @property
    def eos_index(self) -> int:
        return self.tokenizer.eot_token

    def collate(self, batch: list[tuple[Tensor, Tensor]]) -> tuple[Tensor, Tensor]:
        """Custom collate function that left-pads the samples in a batch to equal length."""

        samples, labels = [], []

        for x, y in batch:
            samples.append(x)
            labels.append(y)

        # Note that the `padding_side` argument requires a recent version of
        # PyTorch; it is not available in older releases of `pad_sequence`.
        x = pad_sequence(
            samples,
            batch_first=True,
            padding_value=self.PADDING_INDEX,
            padding_side="left",
        )

        y = pad_sequence(
            labels,
            batch_first=True,
            padding_value=self.PADDING_INDEX,
            padding_side="left",
        )

        assert x.shape == y.shape, "Sample / label batch shape mismatch."

        return x, y

    def __getitem__(self, index: int):
        row = self.dataset[index]

        has_input = len(row["input"]) > 0

        if has_input:
            text = self.PROMPT_TEMPLATE_WITH_INPUT.format(
                input=row["input"], instruction=row["instruction"]
            )
        else:
            text = self.PROMPT_TEMPLATE.format(instruction=row["instruction"])

        tokens = self.tokenizer.encode_ordinary(text)

        sample = deepcopy(tokens)

        # Optionally mask the prompt portion of the labels so that the loss is
        # only computed over the response tokens.
        if self.mask_input:
            labels = [self.PADDING_INDEX] * len(tokens)
        else:
            labels = deepcopy(tokens)

        text = self.RESPONSE_TEMPLATE.format(output=row["output"])

        tokens = self.tokenizer.encode_ordinary(text)

        tokens.append(self.tokenizer.eot_token)

        sample.extend(tokens)
        labels.extend(tokens)

        # Shift the labels one position to the left for next-token prediction.
        end = min(len(sample), self.max_tokens_per_sample + 1)

        x = torch.tensor(sample[0 : end - 1], dtype=torch.int64)
        y = torch.tensor(labels[1:end], dtype=torch.int64)

        assert x.shape == y.shape, "Sample / label shape mismatch."

        return x, y

    def __len__(self):
        return len(self.dataset)
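

# ---------------------------------------------------------------------------
# A minimal usage sketch, not part of the original module. It assumes the
# "o200k_base" tiktoken encoding and illustrative hyperparameters; swap in
# whatever encoding and settings your checkpoint actually uses. Running it
# will download and tokenize the datasets on first use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from tiktoken import get_encoding

    from torch.utils.data import DataLoader

    tokenizer = get_encoding("o200k_base")

    # Pretraining: random 1,024-token windows from the memory-mapped corpus.
    pretrain = Fineweb(tokenizer, subset="sample-10BT", split="train")

    x, y = next(iter(pretrain))

    print(x.shape, y.shape)  # (1024,) (1024,)

    # Instruction tuning: prompt/response pairs, left-padded by Alpaca.collate.
    fine_tune = Alpaca(tokenizer, max_tokens_per_sample=1024, mask_input=True)

    loader = DataLoader(fine_tune, batch_size=4, collate_fn=fine_tune.collate)

    x, y = next(iter(loader))

    print(x.shape, y.shape)  # (4, T) (4, T) where T <= 1024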