# Copyright (c) Meta Platforms, Inc. and affiliates.
from logging import getLogger
from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict

from bytelatent.data.data_types import BltSequence
from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.preprocess_iterator import (
    PreprocessIterator,
    PreprocessIteratorState,
)

logger = getLogger()


class SequencePackingArgs(BaseModel):
    """Configuration for packing variable-length examples into fixed-patch-count sequences."""

    model_config = ConfigDict(extra="forbid")
    output_seq_len: int
    buffer_size: int
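

# Illustrative sizing note: create_iter() below accumulates
# buffer_size * output_seq_len patch lengths before emitting anything, so with
# hypothetical values output_seq_len=4096 and buffer_size=64 each shuffle buffer
# holds 64 * 4096 = 262,144 patches.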


class SequenceIteratorState(PydanticIteratorState):
    """Serializable state that can rebuild an equivalent SequenceIterator."""

    model_config = ConfigDict(extra="forbid")
    sequence_packing_args: SequencePackingArgs
    preprocess_iterator_state: PreprocessIteratorState
    # If None, rng is disabled (sequences are yielded in buffer order).
    rng_state: dict[str, Any] | None

    def build(self):
        preprocess_iterator = self.preprocess_iterator_state.build()
        return SequenceIterator(
            preprocess_iterator,
            sequence_packing_args=self.sequence_packing_args,
            rng_state=self.rng_state,
        )


def get_datafile(
    iterator: PreprocessIterator | ArrowFileIterator | LoopingIterator | LimitIterator,
):
    """Walk through the wrapper iterators to describe the underlying Arrow data file."""
    if isinstance(iterator, ArrowFileIterator):
        n_shards = (
            len(iterator.dataset_files) if iterator.dataset_files is not None else None
        )
        return f"file={iterator.file_path} n_shards={n_shards}"
    elif isinstance(iterator, PreprocessIterator):
        return get_datafile(iterator.arrow_iterator)
    elif isinstance(iterator, LoopingIterator):
        return get_datafile(iterator.file_iterator)
    elif isinstance(iterator, LimitIterator):
        return get_datafile(iterator.base_iterator)
    else:
        raise NotImplementedError()
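

# Example (illustrative values): for a pipeline wrapping an ArrowFileIterator,
# get_datafile returns a debug string such as
#   "file=/data/shard_00.arrow n_shards=8"
# where the path and shard count are made up here for illustration.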


class SequenceIterator(StatefulIterator):
    """Packs preprocessed examples into sequences with a fixed number of patches,
    buffering and (optionally) shuffling them before yielding."""

    def __init__(
        self,
        preprocess_iterator: PreprocessIterator,
        *,
        rng_state: dict[str, Any] | None,
        sequence_packing_args: SequencePackingArgs,
    ):
        self.preprocess_iterator = preprocess_iterator
        self.sequence_packing_args = sequence_packing_args
        self.output_seq_len = sequence_packing_args.output_seq_len
        self.buffer_size = sequence_packing_args.buffer_size
        if rng_state is None:
            self.rng = None
        else:
            # Restore the generator exactly where it left off.
            self.rng = np.random.default_rng()
            self.rng.bit_generator.state = rng_state

    def get_state(self):
        # TODO: need to also persist the current shuffle buffer
        return SequenceIteratorState(
            sequence_packing_args=self.sequence_packing_args,
            preprocess_iterator_state=self.preprocess_iterator.get_state(),
            rng_state=None if self.rng is None else self.rng.bit_generator.state,
        )

    def create_iter(self):
        example_iter = self.preprocess_iterator.create_iter()
        n_buffer_patches = self.buffer_size * self.output_seq_len
        patch_lengths: list[int] = []
        tokens: list[int] = []
        mask: list[bool] = []
        first = True
        logger.info(
            "Starting first buffer for: %s",
            get_datafile(self.preprocess_iterator),
        )
        for example in example_iter:
            assert example.tokens is not None
            assert example.mask is not None
            if self.preprocess_iterator.add_patches:
                assert example.patch_lengths is not None
                assert len(example.tokens) == sum(example.patch_lengths)
            else:
                assert example.patch_lengths is None
            assert len(example.tokens) != 0
            assert len(example.mask) != 0
            assert len(example.tokens) == len(example.mask)

            tokens.extend(example.tokens)
            mask.extend(example.mask)
            if self.preprocess_iterator.add_patches:
                patch_lengths.extend(example.patch_lengths)
            else:
                # Without patching, treat every byte as its own length-1 patch so
                # the buffering logic below works unchanged and yields byte sequences.
                patch_lengths.extend([1] * len(example.tokens))

            while len(patch_lengths) >= n_buffer_patches:
                if first:
                    first = False
                    logger.info(
                        "First buffer complete for: %s",
                        get_datafile(self.preprocess_iterator),
                    )
                x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
                    self.buffer_size, self.output_seq_len
                )
                seq_tokens = []
                seq_mask = []
                start_id = 0
                # The number of patches per sequence (and therefore global steps
                # per batch) is fixed, so each sequence carries a variable number
                # of tokens that we must account for here.
                for num_tokens in x_patches.sum(axis=-1):
                    seq_tokens.append(tokens[start_id : start_id + num_tokens])
                    seq_mask.append(mask[start_id : start_id + num_tokens])
                    start_id += num_tokens
                assert start_id == x_patches.sum()

                # Remove what we just consumed from the buffers.
                patch_lengths = patch_lengths[n_buffer_patches:]
                tokens = tokens[x_patches.sum() :]
                mask = mask[x_patches.sum() :]

                seq_patch_lengths: list[list[int]] = x_patches.tolist()
                assert len(seq_patch_lengths) == self.buffer_size
                if self.rng is None:
                    permutations = list(range(len(seq_patch_lengths)))
                else:
                    permutations = self.rng.permutation(len(seq_patch_lengths))
                for idx in permutations:
                    assert len(seq_patch_lengths[idx]) == self.output_seq_len
                    assert (
                        sum(seq_patch_lengths[idx])
                        == len(seq_tokens[idx])
                        == len(seq_mask[idx])
                    ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
                    assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
                    if self.preprocess_iterator.add_patches:
                        yield BltSequence(
                            tokens=seq_tokens[idx],
                            mask=seq_mask[idx],
                            patch_lengths=seq_patch_lengths[idx],
                        )
                    else:
                        yield BltSequence(
                            tokens=seq_tokens[idx],
                            mask=seq_mask[idx],
                            patch_lengths=None,
                        )
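

# Usage sketch (illustrative, not part of the module): checkpoint/resume with the
# APIs defined above. `make_preprocess_iterator` is a hypothetical stand-in for
# however the surrounding pipeline constructs its PreprocessIterator.
#
#   packing_args = SequencePackingArgs(output_seq_len=4096, buffer_size=64)
#   iterator = SequenceIterator(
#       make_preprocess_iterator(),  # hypothetical helper
#       sequence_packing_args=packing_args,
#       rng_state=np.random.default_rng(seed=0).bit_generator.state,
#   )
#   state = iterator.get_state()  # pydantic model, safe to serialize
#   resumed = state.build()       # reconstructs an equivalent iterator
#   for seq in resumed.create_iter():
#       ...                       # each seq is a BltSequence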