TAPA / lit_llama /packed_dataset.py
xuxw98's picture
Upload 58 files
7d52396
raw
history blame
7.88 kB
# Very loosely inspired by indexed_dataset in Fairseq, Megatron
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
import os
import struct
import random
import numpy as np
import torch
from torch.utils.data import IterableDataset, get_worker_info
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float32,
7: np.float64,
8: np.uint16,
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
HDR_MAGIC = b"LITPKDS"
HDR_SIZE = 24 # bytes
class PackedDataset(IterableDataset):
def __init__(self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0):
self._filenames = filenames
self._n_chunks = n_chunks
self._block_size = block_size
self._seed = seed
self._shuffle = shuffle
self._wrap = wrap
self._num_processes = num_processes
self._process_rank = process_rank
def __iter__(self):
worker_info = get_worker_info()
num_workers = worker_info.num_workers if worker_info is not None else 1
worker_id = worker_info.id if worker_info is not None else 0
num_shards = num_workers * self._num_processes
shard_id = self._process_rank * num_workers + worker_id
max_num_files = len(self._filenames) // num_shards * num_shards
filenames = self._filenames[shard_id : max_num_files : num_shards]
return PackedDatasetIterator(
filenames=filenames,
n_chunks=self._n_chunks,
block_size=self._block_size,
seed=self._seed,
shuffle=self._shuffle,
wrap=self._wrap,
)
class PackedDatasetBuilder(object):
def __init__(
self,
outdir,
prefix,
chunk_size,
sep_token,
dtype="auto",
vocab_size=None,
):
if dtype == "auto":
if vocab_size is None:
raise ValueError("vocab_size cannot be None when dtype='auto'")
if vocab_size is not None and vocab_size < 65500:
self._dtype = np.uint16
else:
self._dtype = np.int32
else:
self._dtype = dtype
self._counter = 0
self._chunk_size = chunk_size
self._outdir = outdir
self._prefix = prefix
self._sep_token = sep_token
self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
self._arr.fill(self._sep_token)
self._idx = 0
self._version = 1
self._filenames = []
def _write_chunk(self):
filename = f"{self._prefix}_{self._counter:010d}.bin"
filename = os.path.join(self._outdir, filename)
with open(filename, "wb") as f:
f.write(HDR_MAGIC)
f.write(struct.pack("<Q", self._version))
f.write(struct.pack("<B", code(self._dtype)))
f.write(struct.pack("<Q", self._chunk_size))
f.write(self._arr.tobytes(order="C"))
self._filenames.append(filename)
self._counter += 1
self._arr.fill(self._sep_token)
self._idx = 0
@property
def dtype(self):
return self._dtype
@property
def filenames(self):
return self._filenames.copy()
def add_array(self, arr):
while self._idx + arr.shape[0] > self._chunk_size:
part_len = self._chunk_size - self._idx
self._arr[self._idx : self._idx + part_len] = arr[:part_len]
self._write_chunk()
arr = arr[part_len:]
arr_len = arr.shape[0]
self._arr[self._idx : self._idx + arr_len] = arr
self._idx += arr_len
def write_reminder(self):
self._write_chunk()
class PackedDatasetIterator:
def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
self._seed = seed
self._shuffle = shuffle
self._rng = np.random.default_rng(seed) if shuffle else None
self._block_idxs = None
self._wrap = wrap
# TODO: instead of filenames, we could have a single text stream
# (or text file) with the sequence of all files to be
# fetched/loaded.
self._filenames = filenames
self._file_idx = 0
self._n_chunks = n_chunks
self._dtype = None
self._block_size = block_size
self._n_blocks = None
self._mmaps = []
self._buffers = []
self._block_idxs = []
self._curr_idx = 0
self._load_n_chunks()
def _read_header(self, path):
with open(path, "rb") as f:
magic = f.read(len(HDR_MAGIC))
assert magic == HDR_MAGIC, "File doesn't match expected format."
version = struct.unpack("<Q", f.read(8))
assert (1,) == version
(dtype_code,) = struct.unpack("<B", f.read(1))
dtype = dtypes[dtype_code]
(chunk_size,) = struct.unpack("<Q", f.read(8))
return dtype, chunk_size
def _close_mmaps(self):
for mmap in self._mmaps:
mmap._mmap.close()
def _load_n_chunks(self):
self._close_mmaps()
self._mmaps = []
self._buffers = []
if self._n_chunks > len(self._filenames[self._file_idx:]):
if not self._wrap:
raise StopIteration
else:
self._file_idx = 0
for i in range(self._n_chunks):
filename = self._filenames[self._file_idx + i]
if self._dtype is None:
self._dtype, self._chunk_size = self._read_header(
filename
)
self._n_blocks = self._chunk_size // self._block_size
# TODO: check header matches with previous files
mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
self._mmaps.append(mmap)
self._buffers.append(memoryview(mmap))
self._file_idx += self._n_chunks
n_all_blocks = self._n_chunks * self._n_blocks
self._block_idxs = (
self._rng.permutation(n_all_blocks)
if self._shuffle
else range(n_all_blocks)
)
self._curr_idx = 0
def __del__(self):
self._close_mmaps()
del self._mmaps
del self._buffers
def __iter__(self):
return self
def __next__(self):
if self._curr_idx >= len(self._block_idxs):
self._load_n_chunks()
# TODO: trigger fetching next next n_chunks if remote
block_idx = self._block_idxs[self._curr_idx]
chunk_id = block_idx // self._n_blocks
buffer = self._buffers[chunk_id]
elem_id = (block_idx % self._n_blocks) * self._block_size
offset = np.dtype(self._dtype).itemsize * elem_id
arr = np.frombuffer(
buffer, dtype=self._dtype, count=self._block_size, offset=offset
)
self._curr_idx += 1
return torch.from_numpy(arr.astype(np.int64))
class CombinedDataset(IterableDataset):
def __init__(self, datasets, seed, weights=None):
self._seed = seed
self._datasets = datasets
self._weights = weights
n_datasets = len(datasets)
if weights is None:
self._weights = [1 / n_datasets] * n_datasets
def __iter__(self):
return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
class CombinedDatasetIterator:
def __init__(self, datasets, seed, weights):
self._datasets = [iter(el) for el in datasets]
self._weights = weights
self._rng = random.Random(seed)
def __next__(self):
dataset, = self._rng.choices(self._datasets, weights=self._weights, k=1)
return next(dataset)