import os
from unittest.mock import MagicMock

import requests
from torch.utils.data import IterableDataset


def train_tokenizer(destination_path):
    destination_path.mkdir(parents=True, exist_ok=True)

    # download the tiny shakespeare dataset
    input_file_path = destination_path / "input.txt"
    if not input_file_path.exists():
        data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        with open(input_file_path, "w") as f:
            f.write(requests.get(data_url).text)

    from lit_llama import Tokenizer
    Tokenizer.train(
        input=input_file_path,
        destination=destination_path,
        vocab_size=100,
    )

    return destination_path / "tokenizer.model"


def test_packed_dataset(tmp_path):
    tokenizer_path = train_tokenizer(tmp_path)

    from lit_llama import Tokenizer
    tokenizer = Tokenizer(tokenizer_path)

    texts = [
        "The moment of truth is upon us.",
        "Time to open the fridge.",
    ]

    from lit_llama.packed_dataset import PackedDatasetBuilder, PackedDataset, HDR_SIZE

    block_size = 10
    n_blocks = 2
    chunk_size = block_size * n_blocks

    builder = PackedDatasetBuilder(
        outdir=tmp_path,
        prefix="packed_dataset",
        chunk_size=chunk_size,
        sep_token=tokenizer.bos_id,
        dtype="auto",
        vocab_size=100,
    )

    for text in texts:
        text_ids = tokenizer.encode(text)
        assert text_ids[0] == tokenizer.bos_id
        builder.add_array(text_ids)

    filenames = builder.filenames

    assert len(filenames) == 2
    assert os.path.basename(filenames[0]) == "packed_dataset_0000000000.bin"
    assert os.path.basename(filenames[1]) == "packed_dataset_0000000001.bin"

    import numpy as np

    ex_tokenized = [tokenizer.encode(text).numpy().astype(builder.dtype) for text in texts]
    ex_tokenized = np.concatenate(ex_tokenized)
    ex_tokenized = ex_tokenized[:2 * chunk_size]

    for filename, el in zip(filenames, np.array_split(ex_tokenized, 2)):
        mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
        count = len(mmap) // np.dtype(builder.dtype).itemsize
        arr = np.frombuffer(mmap, dtype=builder.dtype, count=count, offset=0)
        # we expect two BOS tokens in total, one per file; np.where returns a
        # tuple of index arrays, so count the matches in its first element
        where_bos = np.where(arr == tokenizer.bos_id)
        assert len(where_bos[0]) == 1
        assert np.array_equal(arr, el)

    # without shuffling, blocks come back in on-disk order
    dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, shuffle=False)

    ex_split = np.array_split(ex_tokenized, ex_tokenized.shape[0] // block_size)

    for item, el in zip(dataset, ex_split):
        assert np.array_equal(item, el)

    # with a seed, blocks are returned in the iterator's shuffled order
    dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, seed=12345)

    for i, item in enumerate(dataset):
        block_idxs = iter(dataset)._block_idxs
        assert np.array_equal(item, ex_split[block_idxs[i]])

    # wrap=True lets iteration continue past the end of the data
    dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, seed=12345, wrap=True)

    for i, item in enumerate(dataset):
        if i > 24:
            break

    # with a single chunk in memory at a time, shuffling happens within each chunk
    dataset = PackedDataset(filenames=filenames, n_chunks=1, block_size=block_size, seed=12345)

    for i, item in enumerate(dataset):
        block_idxs = iter(dataset)._block_idxs
        chunk_idx = i // n_blocks * n_blocks
        assert np.array_equal(item, ex_split[chunk_idx + block_idxs[i % n_blocks]])

    # a smaller block size that still divides the chunk size evenly
    block_size_ = block_size // 2

    ex_split = np.array_split(ex_tokenized, ex_tokenized.shape[0] // block_size_)

    dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size_, seed=12345)

    for i, item in enumerate(dataset):
        block_idxs = iter(dataset)._block_idxs
        assert np.array_equal(item, ex_split[block_idxs[i]])

    # a block size that does not divide the chunk size: trailing tokens
    # in each chunk are dropped
    block_size_ = block_size // 3
    n_chunks = 2

    ex_chunks = np.split(ex_tokenized, n_chunks)
    n_splits = ex_tokenized.shape[0] // n_chunks // block_size_
    ex_splits = [np.split(el[:n_splits * block_size_], n_splits) for el in ex_chunks]
    ex_split = sum(ex_splits, [])

    dataset = PackedDataset(filenames=filenames, n_chunks=n_chunks, block_size=block_size_, seed=12345)

    for i, item in enumerate(dataset):
        block_idxs = iter(dataset)._block_idxs
        assert np.array_equal(item, ex_split[block_idxs[i]])
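

# A hedged usage sketch (not part of the original test suite): PackedDataset is
# an IterableDataset, so it can presumably be consumed through a standard torch
# DataLoader in training code. The helper name, batch size, and the wrap/seed
# values below are illustrative assumptions, not asserted behavior. The leading
# underscore keeps pytest from collecting it as a test.
def _example_dataloader_usage(filenames, block_size):
    from torch.utils.data import DataLoader

    from lit_llama.packed_dataset import PackedDataset

    # wrap=True keeps yielding blocks indefinitely, which suits step-based training loops
    dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, seed=12345, wrap=True)
    loader = DataLoader(dataset, batch_size=4, num_workers=0)
    for batch in loader:
        # each batch stacks `batch_size` blocks of `block_size` token ids
        return batch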


class SimpleDataset(IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        self._start = start
        self._end = end

    def __iter__(self):
        return iter(range(self._start, self._end))


def test_combined_dataset():
    from lit_llama.packed_dataset import CombinedDataset

    dataset1 = SimpleDataset(0, 10)
    dataset2 = SimpleDataset(10, 20)
    dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[1.0, 0.0], seed=12345)

    res = [el for el in dataset]
    assert res == list(range(0, 10))

    dataset1 = SimpleDataset(0, 10)
    dataset2 = SimpleDataset(10, 20)
    dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[0.0, 1.0], seed=12345)

    res = [el for el in dataset]
    assert res == list(range(10, 20))

    dataset1 = SimpleDataset(0, 10)
    dataset2 = SimpleDataset(10, 20)
    dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)

    res = [el for el in dataset]
    # with equal weights the sampling order is random, so only weak assertions
    # hold: iteration stops as soon as either dataset is exhausted
    assert 9 in res or 19 in res
    if len(res) > 10:
        assert 0 in res and 10 in res


def test_sharded_packed_dataset(monkeypatch):
    import lit_llama.packed_dataset
    from lit_llama.packed_dataset import PackedDataset

    dataset_iterator_mock = MagicMock()
    monkeypatch.setattr(lit_llama.packed_dataset, "PackedDatasetIterator", dataset_iterator_mock)
    filenames = [str(i) for i in range(10)]

    # world_size = 1, rank = 0
    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2))
    assert dataset_iterator_mock.call_args[1]["filenames"] == filenames
    dataset_iterator_mock.reset_mock()

    # world_size = 2, rank = 0
    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=2, process_rank=0))
    assert dataset_iterator_mock.call_args[1]["filenames"] == ["0", "2", "4", "6", "8"]
    dataset_iterator_mock.reset_mock()

    # world_size = 2, rank = 1
    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=2, process_rank=1))
    assert dataset_iterator_mock.call_args[1]["filenames"] == ["1", "3", "5", "7", "9"]
    dataset_iterator_mock.reset_mock()

    # world_size = 3, rank = 0 (dataset size not cleanly divisible by world size)
    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=0))
    assert dataset_iterator_mock.call_args[1]["filenames"] == ["0", "3", "6"]
    dataset_iterator_mock.reset_mock()

    # world_size = 3, rank = 1 (dataset size not cleanly divisible by world size)
    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=1))
    assert dataset_iterator_mock.call_args[1]["filenames"] == ["1", "4", "7"]
    dataset_iterator_mock.reset_mock()

    # world_size = 3, rank = 2 (dataset size not cleanly divisible by world size)
    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=2))
    assert dataset_iterator_mock.call_args[1]["filenames"] == ["2", "5", "8"]
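

# A hedged note on the sharding pattern asserted above: the expected filename
# assignments are consistent with truncating the file list to a multiple of the
# world size and then striding by rank. The sketch below demonstrates that rule
# in isolation; it is an inference from the assertions, not code taken from the
# PackedDataset implementation. Underscore-prefixed so pytest does not collect it.
def _example_shard_rule():
    filenames = [str(i) for i in range(10)]
    num_processes = 3
    # drop the tail so every rank receives the same number of files
    max_num_files = len(filenames) // num_processes * num_processes  # 9
    shards = [filenames[rank:max_num_files:num_processes] for rank in range(num_processes)]
    assert shards == [["0", "3", "6"], ["1", "4", "7"], ["2", "5", "8"]]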