# play_with_baby_llama2 / tinystories.py
# Ritori's picture
# Upload folder using huggingface_hub
# 613e5d1
# raw
# history blame
# 6.2 kB
"""
Download, preprocess and serve the TinyStories dataset as a DataLoader.
"""
import argparse
import glob
import json
import os
import random
from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import requests
import torch
import torch.distributed as dist
from tqdm import tqdm
from tokenizer import Tokenizer
# Root directory for the downloaded archive and the unpacked / tokenized shards.
# It is a relative path, so it is resolved against the current working directory.
DATA_CACHE_DIR = "data"
def download_file(url: str, fname: str, chunk_size=1024):
    """Download the resource at `url` into the local file `fname`.

    The response body is streamed in pieces of `chunk_size` bytes while a tqdm
    progress bar tracks the number of bytes written.

    The data is first written to `fname + ".part"` and only renamed to `fname`
    once the whole transfer has finished, so an interrupted download does not
    leave a truncated file under the final name.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    partial_fname = fname + ".part"
    # the context manager releases the connection even if writing fails
    with requests.get(url, stream=True) as resp:
        # fail loudly instead of saving an error page as if it were the payload
        resp.raise_for_status()
        # servers may omit content-length; the bar then has no known total
        total = int(resp.headers.get("content-length", 0))
        with open(partial_fname, "wb") as file, tqdm(
            desc=fname,
            total=total,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in resp.iter_content(chunk_size=chunk_size):
                size = file.write(data)
                bar.update(size)
    # publish the finished file under its final name
    os.replace(partial_fname, fname)
def download():
    """Download and unpack the TinyStories dataset under DATA_CACHE_DIR.

    Steps, each skipped when its output already exists:
      1. download TinyStories_all_data.tar.gz,
      2. unpack it into the TinyStories_all_data directory (JSON shards).
    Finally the number of shards and one example story are printed.

    Raises:
        subprocess.CalledProcessError: if `tar` exits with a non-zero status.
        FileNotFoundError: if no JSON shard is present after unpacking.
    """
    import shutil
    import subprocess

    os.makedirs(DATA_CACHE_DIR, exist_ok=True)

    # download the TinyStories dataset, unless it's already downloaded
    data_url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"
    data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
    if not os.path.exists(data_filename):
        print(f"Downloading {data_url} to {data_filename}...")
        download_file(data_url, data_filename)
    else:
        print(f"{data_filename} already exists, skipping download...")

    # unpack the tar.gz file into all the data shards (json files)
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    if not os.path.exists(data_dir):
        os.makedirs(data_dir, exist_ok=True)
        print(f"Unpacking {data_filename}...")
        try:
            # argument list (no shell) so paths are never re-parsed by a shell;
            # check=True turns a failing tar into an exception
            subprocess.run(["tar", "-xzf", data_filename, "-C", data_dir], check=True)
        except (subprocess.CalledProcessError, OSError, KeyboardInterrupt):
            # remove the partial output, otherwise the existence check above
            # would skip unpacking on the next run
            shutil.rmtree(data_dir, ignore_errors=True)
            raise
    else:
        print(f"{data_dir} already exists, skipping unpacking...")

    # print a single example just for debugging and such
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    if not shard_filenames:
        raise FileNotFoundError(f"No .json shards found in {data_dir}")
    with open(shard_filenames[0], "r", encoding="utf-8") as f:
        data = json.load(f)
    print("Download done.")
    print(f"Number of shards: {len(shard_filenames)}")
    print(f"Example story:\n{data[0]}")
def pretokenize():
    """Tokenize every JSON shard and store the tokens next to it as a .bin file.

    All `*.json` shards in DATA_CACHE_DIR/TinyStories_all_data are processed in
    a pool of 8 threads. For each shard the "story" field of every example is
    encoded (with BOS, without EOS), the tokens of all stories are concatenated
    and written as raw uint16 values to a file with the same name but a `.bin`
    extension.

    Any exception raised while processing a shard is re-raised here.
    """
    enc = Tokenizer()

    def process_shard(shard):
        """Tokenize a single JSON shard and write the matching .bin file."""
        with open(shard, "r", encoding="utf-8") as f:
            data = json.load(f)
        all_tokens = []
        for example in tqdm(data):
            text = example["story"]
            text = text.strip()  # get rid of leading/trailing whitespace
            tokens = enc.encode(text, bos=True, eos=False)  # encode the text, use BOS
            all_tokens.extend(tokens)
        # convert to uint16 nparray
        # NOTE(review): assumes every token id fits in 16 bits -- confirm
        # against the Tokenizer's vocabulary size
        all_tokens = np.array(all_tokens, dtype=np.uint16)
        # write to disk; swap only the extension, never other parts of the path
        tokenized_filename = os.path.splitext(shard)[0] + ".bin"
        with open(tokenized_filename, "wb") as f:
            f.write(all_tokens.tobytes())
        print(f"Saved {tokenized_filename}")

    # iterate the shards and tokenize all of them
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))

    # process all the shards in a threadpool
    with ThreadPoolExecutor(max_workers=8) as executor:
        # map() only re-raises a worker's exception when its result is
        # retrieved, so drain the iterator to surface failures
        for _ in executor.map(process_shard, shard_filenames):
            pass
    print("Done.")
class PretokDataset(torch.utils.data.IterableDataset):
    """Loads pretokenized examples from disk and yields them as PyTorch tensors.

    The `.bin` shards in DATA_CACHE_DIR/TinyStories_all_data are read as flat
    uint16 token streams. The first shard (in sorted order) is the test split,
    all remaining shards are the train split. Iteration never ends: shards and
    the chunks within each shard are reshuffled and replayed forever.
    """

    def __init__(self, split, max_seq_len):
        """Store the split name ("train" selects the train shards, anything
        else the test shard) and the length of each yielded sequence."""
        super().__init__()
        self.split = split
        self.max_seq_len = max_seq_len

    def __iter__(self):
        """Yield `(x, y)` int64 tensors of length `max_seq_len` forever.

        `y` is `x` shifted by one token, i.e. the next-token targets.

        Raises:
            FileNotFoundError: if no `.bin` shard exists for the chosen split.
        """
        # get worker info within a DataLoader
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        # get DDP rank info
        rank = dist.get_rank() if dist.is_initialized() else 0
        # combine the worker_id and worker_rank to create a unique seed for rng
        seed = 42 + worker_id + 1337 * rank
        rng = random.Random(seed)
        print(f"Created a PretokDataset with rng seed {seed}")
        data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
        shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.bin")))
        # train/test split. let's use only shard 0 for test split, rest train
        shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
        if not shard_filenames:
            # without this check the endless loop below would spin forever
            # without yielding anything
            raise FileNotFoundError(
                f"No pretokenized .bin shards for split {self.split!r} in {data_dir}"
            )
        while True:
            rng.shuffle(shard_filenames)
            for shard in shard_filenames:
                # open the dataset for reading but keep it on disk with memmap
                m = np.memmap(shard, dtype=np.uint16, mode="r")
                num_batches = len(m) // self.max_seq_len
                # every chunk reads max_seq_len + 1 tokens, so dropping the
                # last one keeps each slice fully inside the shard
                num_batches -= 1
                assert num_batches > 0, "this shard is way too small? investigate."
                ixs = list(range(num_batches))
                rng.shuffle(ixs)
                for ix in ixs:
                    start = ix * self.max_seq_len
                    end = start + self.max_seq_len + 1
                    # calling .astype will copy the data into a new numpy array, now in RAM
                    chunk = torch.from_numpy((m[start:end]).astype(np.int64))
                    x = chunk[:-1]
                    y = chunk[1:]
                    yield x, y
class Task:
    """Entry point for obtaining batches of pretokenized data."""

    @staticmethod
    def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
        """Yield `(x, y)` batches from a PretokDataset, moved to `device`.

        A DataLoader with pinned memory wraps the dataset; each batch is
        transferred with `non_blocking=True` before being yielded.
        """
        dataset = PretokDataset(split, max_seq_len)
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            pin_memory=True,
            num_workers=num_workers,
        )
        for inputs, targets in loader:
            # inputs are moved first, then targets, then the pair is handed out
            yield (
                inputs.to(device, non_blocking=True),
                targets.to(device, non_blocking=True),
            )
if __name__ == "__main__":
    # maps each command-line stage name to the function that implements it
    fun = {
        "download": download,
        "pretokenize": pretokenize,
    }
    parser = argparse.ArgumentParser()
    # the accepted choices come from the dispatch table, so argparse rejects
    # any stage that has no implementation instead of failing with a KeyError
    parser.add_argument("stage", type=str, choices=list(fun))
    args = parser.parse_args()
    # depending on the stage call the appropriate function
    fun[args.stage]()