# Copyright (c) Meta Platforms, Inc. and affiliates.
from enum import Enum
from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict

from bytelatent.data.data_types import Batch, BltSequence
from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState


class PackingMode(str, Enum):
    BYTES = "bytes"
    PATCHING = "patching"


class PackingArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    batch_size: int
    seq_len: int
    pad_id: int
    max_length: int | None
    pad_to_max_length: bool
    enable_byte_ngrams: bool
    packing_mode: PackingMode
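
# Illustrative configuration sketch (the values below are assumptions for
# clarity, not defaults from this repo). A byte-level packing setup might be:
#   PackingArgs(
#       batch_size=8, seq_len=1024, pad_id=0, max_length=None,
#       pad_to_max_length=False, enable_byte_ngrams=False,
#       packing_mode=PackingMode.BYTES,
#   )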


class PackingIteratorState(PydanticIteratorState):
    model_config = ConfigDict(extra="forbid")
    sequence_iterator_state: SamplingIteratorState
    packing_args: PackingArgs

    def build(self) -> "PackingIterator":
        return PackingIterator(
            sequence_iterator=self.sequence_iterator_state.build(),
            packing_args=self.packing_args,
        )


def _merge_patch_seq_masks(bs: int, slen: int, mask_seqs: list[list[bool]]):
    assert len(mask_seqs) == bs
    lens = [len(m) for m in mask_seqs]
    if all(all(m) for m in mask_seqs) and all(lens[0] == l for l in lens):
        return np.ones((bs, slen), dtype=bool)
    assert slen == max(lens) - 1, f"slen={slen} != max(lens)-1={max(lens) - 1}"
    mask = np.zeros((bs, slen), dtype=bool)
    for i, m in enumerate(mask_seqs):
        if m is None:
            raise NotImplementedError(
                "None masks are not supported; pass an explicit all-True mask instead."
            )
        # Shift the mask by one position so it lines up with the next-token targets in y.
        mask[i][: len(mask_seqs[i]) - 1] = mask_seqs[i][1:]
    return mask
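
# Worked example (illustrative only, not used by the module): with bs=1 and
# slen=3, a single mask of length 4 is shifted so it aligns with next-token
# targets, and the leading position is dropped:
#   _merge_patch_seq_masks(1, 3, [[True, True, False, True]])
#   -> array([[ True, False,  True]])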


def truncate_batch(
    batch: Batch,
    max_length: int,
    pad_id: int,
    pad_to_max_length: bool = False,
    *,
    enable_byte_ngrams: bool,
):
    """
    Truncate batch.x to max_length (in place), removing the corresponding patch
    sizes from batch.patch_lengths and fixing batch.mask.

    batch.patch_lengths keeps its shape; x, y, and mask may shrink, or grow up
    to max_length when pad_to_max_length is set.
    """
    if batch.patch_lengths is None:
        return batch
    seq_lengths = batch.patch_lengths.sum(axis=1)
    max_length_adj = max_length + 1
    if np.any(seq_lengths > max_length_adj):
        for i in range(batch.x.shape[0]):
            if seq_lengths[i] > max_length_adj:
                # Find id of patch that tips over max_length + 1
                count, j = 0, 0
                while count + batch.patch_lengths[i, j] <= max_length_adj:
                    count += batch.patch_lengths[i, j]
                    j += 1
                # Edit the batch
                assert j < batch.patch_lengths.shape[1]
                batch.x[i, max_length:] = pad_id
                batch.y[i, max_length:] = pad_id
                if batch.mask is not None:
                    batch.mask[i, max_length:] = False
                batch.patch_lengths[i, j:] = 0
                batch.patch_lengths[i, j] = max_length_adj - count
    # Truncate if necessary.
    if max_length < batch.x.shape[1]:
        batch.x = batch.x[:, :max_length]
        batch.y = batch.y[:, :max_length]
        if batch.mask is not None:
            batch.mask = batch.mask[:, :max_length]
    # Right pad to max_length if necessary
    elif pad_to_max_length:
        if batch.x.shape[1] < max_length:
            # NOTE: this has to be done on an actual patch.
            non_zero_indices = (batch.patch_lengths != 0).sum(axis=1) - 1
            non_zero_indices = np.maximum(0, non_zero_indices)
            batch.patch_lengths[range(len(batch.patch_lengths)), non_zero_indices] += (
                max_length - batch.x.shape[1]
            )
            # TODO: We could get rid of many of these complications by moving this
            # function directly into the dataloader.
            x = np.full((batch.x.shape[0], max_length), pad_id, dtype=batch.x.dtype)
            x[:, : batch.x.shape[1]] = batch.x
            batch.x = x
        if batch.y.shape[1] < max_length:
            y = np.full((batch.y.shape[0], max_length), pad_id, dtype=batch.y.dtype)
            y[:, : batch.y.shape[1]] = batch.y
            batch.y = y
        if batch.mask is not None and batch.mask.shape[1] < max_length:
            mask = np.full(
                (batch.mask.shape[0], max_length), False, dtype=batch.mask.dtype
            )
            mask[:, : batch.mask.shape[1]] = batch.mask
            batch.mask = mask
    assert batch.x.shape[1] <= max_length
    assert batch.y.shape[1] <= max_length
    assert batch.mask is None or batch.mask.shape[1] <= max_length
    assert np.all(max_length_adj - batch.patch_lengths.sum(axis=1) == 0)
    if pad_to_max_length:
        assert batch.x.shape[1] == max_length
        assert batch.y.shape[1] == max_length
        assert batch.mask is None or batch.mask.shape[1] == max_length
    if enable_byte_ngrams:
        raise NotImplementedError()
        # NOTE: the lines below are unreachable placeholder code; `tokenizer`
        # is not defined in this scope.
        # (num_ngram, batch_size, seq_len)
        ngram_ids = np.array(tokenizer.encode_token_ngrams(batch.x))
        assert ngram_ids.shape[2] == batch.x.shape[1]
    else:
        ngram_ids = None
    batch.ngram_ids = ngram_ids
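
# Worked example (illustrative, values are assumptions): with max_length=4 and
# pad_id=0, a row whose patch_lengths are [1, 3, 2] (sum 6 > max_length + 1 = 5)
# is cut at the patch that crosses the limit: its patch_lengths become [1, 3, 1]
# (sum exactly 5), and x, y, and mask are truncated to 4 columns, or right-padded
# back out to max_length when pad_to_max_length is True.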


class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
    def __init__(
        self,
        sequence_iterator: StatefulIterator[BltSequence, Any],
        *,
        packing_args: PackingArgs,
    ):
        self.sequence_iterator = sequence_iterator
        self.packing_args = packing_args

    def get_state(self):
        return PackingIteratorState(
            sequence_iterator_state=self.sequence_iterator.get_state(),
            packing_args=self.packing_args,
        )

    def create_iter(self):
        if self.packing_args.packing_mode == PackingMode.BYTES:
            return self._create_iter_from_bytes()
        elif self.packing_args.packing_mode == PackingMode.PATCHING:
            return self._create_iter_from_patch_lengths()
        else:
            raise ValueError(
                f"Invalid packing mode: {self.packing_args.packing_mode}"
            )

    def _create_iter_from_bytes(self):
        sequence_iter = self.sequence_iterator.create_iter()
        batch_size = self.packing_args.batch_size
        pad_id = self.packing_args.pad_id
        seq_len = self.packing_args.seq_len
        while True:
            tokens: list[list[int]] = []
            masks: list[list[bool]] = []
            stop_iteration = False
            try:
                for _ in range(self.packing_args.batch_size):
                    sequence = next(sequence_iter)
                    _tokens = sequence.tokens
                    _mask = sequence.mask
                    assert (
                        sequence.patch_lengths is None
                    ), "patch_lengths should not be used in byte packing"
                    tokens.append(_tokens)
                    masks.append(_mask)
            except StopIteration:
                # At this point there will be no new sequences, so we need to stop
                # after yielding the already accumulated data (one batch).
                # In this case, either:
                # 1. We have a complete batch, so things go as normal.
                # 2. We have an incomplete batch; since the arrays below are
                #    allocated at the full batch size and then filled in, this
                #    case is handled automatically.
                stop_iteration = True
            x = np.full((batch_size, seq_len), fill_value=pad_id)
            y = np.full((batch_size, seq_len), fill_value=pad_id)
            m = np.zeros((batch_size, seq_len), dtype=bool)
            for i, tok_seq in enumerate(tokens):
                x[i, : len(tok_seq)] = tok_seq
                y[i, : len(tok_seq) - 1] = tok_seq[1:]
                m[i, : len(tok_seq)] = masks[i]
            batch = Batch(x=x, y=y, mask=m)
            assert (
                batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
            ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
            yield batch
            if stop_iteration:
                break
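
    # Byte-mode layout sketch (illustrative values, assumed for clarity): for a
    # sequence with tokens [5, 6, 7] packed into seq_len=4 with pad_id=0, the
    # row becomes
    #   x    = [5, 6, 7, 0]
    #   y    = [6, 7, 0, 0]   # next-token targets, shifted left by one
    #   mask = [T, T, T, F]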

    def _create_iter_from_patch_lengths(self):
        sequence_iter = self.sequence_iterator.create_iter()
        batch_size = self.packing_args.batch_size
        pad_id = self.packing_args.pad_id
        seq_len = self.packing_args.seq_len
        pad_to_max_length = self.packing_args.pad_to_max_length
        enable_byte_ngrams = self.packing_args.enable_byte_ngrams
        max_length = self.packing_args.max_length
        assert max_length is not None
        final_leftover_batch = False
        while True:
            tokens: list[list[int]] = []
            masks: list[list[bool]] = []
            patch_lengths: list[list[int]] = []
            stop_iteration = False
            try:
                for _ in range(self.packing_args.batch_size):
                    sequence = next(sequence_iter)
                    _tokens = sequence.tokens
                    _mask = sequence.mask
                    _patch_lengths = sequence.patch_lengths
                    assert (
                        _patch_lengths is not None
                    ), "patch lengths are required for packing based on patches."
                    # Reminder: seq_len is in terms of patches
                    assert len(sequence.patch_lengths) == self.packing_args.seq_len
                    last_patch_length = 0
                    if _patch_lengths[0] > 1:
                        # Force the first patch to length 1 by splitting it off the
                        # front, and drop the last patch (and its tokens below) so
                        # the number of patches stays equal to seq_len.
                        last_patch_length = _patch_lengths[-1]
                        _patch_lengths[0] -= 1
                        _patch_lengths = [1] + _patch_lengths[:-1]
                    tokens.append(_tokens[: len(_tokens) - last_patch_length])
                    masks.append(_mask[: len(_mask) - last_patch_length])
                    patch_lengths.append(_patch_lengths)
            except StopIteration:
                stop_iteration = True
            if len(tokens) == 0 and stop_iteration:
                break
            x_patch_lengths = np.array(patch_lengths)
            assert (
                x_patch_lengths.shape[1] == seq_len
            ), f"{x_patch_lengths.shape[1]} vs {seq_len}"
            # pad batch to same length
            tok_seq_len = max([len(toks) for toks in tokens]) - 1
            x = np.full((batch_size, tok_seq_len), fill_value=pad_id)
            y = np.full((batch_size, tok_seq_len), fill_value=pad_id)
            for i, tok_seq in enumerate(tokens):
                x[i, : len(tok_seq) - 1] = tok_seq[:-1]
                y[i, : len(tok_seq) - 1] = tok_seq[1:]
                # Adjust patch lengths to match x
                x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1)
            if x_patch_lengths.shape[0] < batch_size:
                if final_leftover_batch:
                    raise ValueError(
                        "There should only be one partial batch, but found multiple"
                    )
                final_leftover_batch = True
                assert len(masks) == len(x_patch_lengths)
                n_missing = batch_size - x_patch_lengths.shape[0]
                # Repeat the last patch length to validly pad it out, but
                # update the mask to ignore the row
                x_patch_lengths = np.vstack(
                    [
                        x_patch_lengths,
                        np.repeat(x_patch_lengths[-1:, :], n_missing, axis=0),
                    ]
                )
                for _ in range(n_missing):
                    masks.append([0] * tok_seq_len)
            assert len(masks) == batch_size
            assert x_patch_lengths.shape == (
                batch_size,
                seq_len,
            ), f"{x_patch_lengths.shape} vs {(batch_size, seq_len)}"
            if enable_byte_ngrams:
                raise NotImplementedError()
            else:
                ngram_ids = None
            batch = Batch(
                x=x,
                y=y,
                patch_lengths=x_patch_lengths,
                ngram_ids=ngram_ids,
                mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
            )
            assert (
                x_patch_lengths.sum() == x.size + batch_size
            ), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
            assert (
                batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
            ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
            assert np.all(
                x_patch_lengths[:, 0] == 1
            ), f"first patch should always be 1, {x_patch_lengths[:, 0]}"
            # cuda_gb_allocated = (torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024)
            # cuda_gb_reserved = torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024
            # print(f"dataloader cuda_gb_allocated: {cuda_gb_allocated}, cuda_gb_reserved: {cuda_gb_reserved}")
            truncate_batch(
                batch,
                max_length=max_length,
                pad_id=pad_id,
                pad_to_max_length=pad_to_max_length,
                enable_byte_ngrams=enable_byte_ngrams,
            )
            yield batch
            if stop_iteration:
                break
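

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. The stub
    # iterator below (and its name) is an assumption made for illustration: it
    # duck-types the StatefulIterator interface (get_state / create_iter) and
    # yields objects exposing the tokens / mask / patch_lengths attributes that
    # PackingIterator reads, standing in for the real sampling pipeline.
    from types import SimpleNamespace

    class _StubSequenceIterator:
        def __init__(self, token_seqs: list[list[int]]):
            self.token_seqs = token_seqs

        def get_state(self):
            return None

        def create_iter(self):
            for toks in self.token_seqs:
                yield SimpleNamespace(
                    tokens=toks, mask=[True] * len(toks), patch_lengths=None
                )

    packer = PackingIterator(
        sequence_iterator=_StubSequenceIterator([[5, 6, 7], [8, 9]]),
        packing_args=PackingArgs(
            batch_size=2,
            seq_len=4,
            pad_id=0,
            max_length=None,
            pad_to_max_length=False,
            enable_byte_ngrams=False,
            packing_mode=PackingMode.BYTES,
        ),
    )
    for packed_batch in packer.create_iter():
        # The final batch is all padding because the stub runs out of sequences.
        print(packed_batch.x)
        print(packed_batch.y)
        print(packed_batch.mask)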