Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
import pytest | |
from bytelatent.data.data_types import BltSequence | |
from bytelatent.data.iterators.abstract_iterator import StatefulIterator | |
from bytelatent.data.iterators.packing_iterator import ( | |
PackingArgs, | |
PackingIterator, | |
PackingMode, | |
_merge_patch_seq_masks, | |
) | |
class DummySequenceIterator(StatefulIterator): | |
def __init__( | |
self, | |
*, | |
seq_len: int, | |
n_seqs: int, | |
patch_lengths: list[int] | None = None, | |
pad_id: int = 0, | |
): | |
self.seq_len = seq_len | |
self.n_seqs = n_seqs | |
self.patch_lengths = patch_lengths | |
self.pad_id = pad_id | |
def get_state(self): | |
raise NotImplementedError() | |
def create_iter(self): | |
for i in range(self.n_seqs): | |
if self.patch_lengths is None: | |
tokens = np.arange( | |
i * self.seq_len + 1, (i + 1) * self.seq_len + 1 | |
).tolist() | |
mask = [True] * self.seq_len # type: ignore | |
assert len(tokens) == self.seq_len | |
else: | |
n = sum(self.patch_lengths) | |
tokens = np.arange(i * n + 1, (i + 1) * n + 1).tolist() | |
assert len(tokens) == n | |
mask = [True] * n | |
assert len(mask) == len(tokens) | |
yield BltSequence( | |
tokens=tokens, | |
mask=mask, | |
patch_lengths=self.patch_lengths, | |
) | |
def create_bytes_iter(*, seq_len: int, n_seqs: int, batch_size: int, pad_id: int): | |
sequence_iterator = DummySequenceIterator(seq_len=seq_len, n_seqs=n_seqs) | |
packing_iterator = PackingIterator( | |
sequence_iterator, | |
packing_args=PackingArgs( | |
batch_size=batch_size, | |
seq_len=seq_len, | |
pad_id=pad_id, | |
packing_mode=PackingMode.BYTES, | |
max_length=None, | |
pad_to_max_length=False, | |
enable_byte_ngrams=False, | |
), | |
) | |
return packing_iterator.create_iter() | |
def create_patches_iter( | |
*, | |
seq_len: int, | |
n_seqs: int, | |
batch_size: int, | |
pad_id: int, | |
patch_lengths: list[int] | None, | |
max_length: int, | |
): | |
sequence_iterator = DummySequenceIterator( | |
# seq_len=number of bytes, which for blt/patches, is max_length since seq_len is | |
# in terms of number of patches | |
seq_len=max_length, | |
n_seqs=n_seqs, | |
patch_lengths=patch_lengths, | |
) | |
packing_iterator = PackingIterator( | |
sequence_iterator, | |
packing_args=PackingArgs( | |
batch_size=batch_size, | |
seq_len=seq_len, | |
pad_id=pad_id, | |
packing_mode=PackingMode.PATCHING, | |
max_length=max_length, | |
pad_to_max_length=True, | |
enable_byte_ngrams=False, | |
), | |
) | |
return packing_iterator.create_iter() | |
def test_last_batch_correctness_bytes(): | |
seq_len = 1024 | |
n_seqs = 10 | |
batch_size = 4 | |
pad_id = 0 | |
iterator = create_bytes_iter( | |
seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id | |
) | |
batches = [] | |
n_nonpad = 0 | |
n_nonmask = 0 | |
for b in iterator: | |
assert b.x.shape[0] == batch_size | |
assert b.x.shape[1] == seq_len | |
n_nonpad += (b.x != pad_id).sum() | |
if b.mask is None: | |
n_nonmask += b.x.size | |
else: | |
n_nonmask += b.mask.sum() | |
batches.append(b) | |
assert len(batches) == 3 | |
assert n_nonpad == n_nonmask == seq_len * n_seqs | |
# The second half of the last batch should be all pads | |
assert batches[-1].mask[2:].sum() == 0 | |
def test_edgecase_batch_correctness_bytes(): | |
seq_len = 1024 | |
n_seqs = 10 | |
batch_size = 12 | |
pad_id = 0 | |
iterator = create_bytes_iter( | |
seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id | |
) | |
batches = [] | |
n_nonpad = 0 | |
n_nonmask = 0 | |
for b in iterator: | |
assert b.x.shape[0] == batch_size | |
assert b.x.shape[1] == seq_len | |
n_nonpad += (b.x != pad_id).sum() | |
if b.mask is None: | |
n_nonmask += b.x.size | |
else: | |
n_nonmask += b.mask.sum() | |
batches.append(b) | |
assert len(batches) == 1 | |
assert n_nonpad == n_nonmask == seq_len * n_seqs | |
# The second half of the last batch should be all pads | |
assert batches[0].mask[10:].sum() == 0 | |
def test_exact_batch_correctness_bytes(): | |
seq_len = 1024 | |
n_seqs = 12 | |
batch_size = 4 | |
pad_id = 0 | |
iterator = create_bytes_iter( | |
seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id | |
) | |
batches = [] | |
n_nonpad = 0 | |
n_nonmask = 0 | |
for b in iterator: | |
assert b.x.shape[0] == batch_size | |
assert b.x.shape[1] == seq_len | |
n_nonpad += (b.x != pad_id).sum() | |
if b.mask is None: | |
n_nonmask += b.x.size | |
else: | |
n_nonmask += b.mask.sum() | |
batches.append(b) | |
assert len(batches) == 4 | |
assert n_nonpad == n_nonmask == seq_len * n_seqs | |
def test_exact_batch_correctness_patches(): | |
# First patch length is forced to be 1 | |
patch_lengths = [1, 255, 256, 256, 256] | |
# Recall: This is in terms of bytes | |
max_length = 1024 | |
# Recall: This is in terms of patches | |
seq_len = 5 | |
n_seqs = 12 | |
batch_size = 4 | |
pad_id = 0 | |
iterator = create_patches_iter( | |
seq_len=seq_len, | |
n_seqs=n_seqs, | |
batch_size=batch_size, | |
pad_id=pad_id, | |
patch_lengths=patch_lengths, | |
max_length=max_length, | |
) | |
batches = [] | |
n_nonpad = 0 | |
n_nonmask = 0 | |
for batch in iterator: | |
assert batch.x.shape[0] == batch_size | |
assert batch.x.shape[1] == max_length | |
n_nonpad += (batch.x != pad_id).sum() | |
if batch.mask is None: | |
n_nonmask += batch.x.size | |
else: | |
n_nonmask += batch.mask.sum() | |
batches.append(batch) | |
assert len(batches) == 3 | |
# max_length - 1 is due to chopping off the last byte for | |
# having a y target | |
assert n_nonpad == n_nonmask == (max_length - 1) * n_seqs | |
def test_short_batch_correctness_patches(): | |
# First patch length is forced to be 1 | |
# Total=48 | |
patch_lengths = [1, 11, 12, 12, 12] | |
# Recall: This is in terms of bytes | |
max_length = 1024 | |
# Recall: This is in terms of patches | |
seq_len = 5 | |
n_seqs = 12 | |
batch_size = 4 | |
pad_id = 0 | |
iterator = create_patches_iter( | |
seq_len=seq_len, | |
n_seqs=n_seqs, | |
batch_size=batch_size, | |
pad_id=pad_id, | |
patch_lengths=patch_lengths, | |
max_length=max_length, | |
) | |
batches = [] | |
n_nonpad = 0 | |
n_nonmask = 0 | |
for batch in iterator: | |
assert batch.x.shape[0] == batch_size | |
assert batch.x.shape[1] == max_length | |
n_nonpad += (batch.x != pad_id).sum() | |
if batch.mask is None: | |
n_nonmask += batch.x.size | |
else: | |
n_nonmask += batch.mask.sum() | |
batches.append(batch) | |
assert len(batches) == 3 | |
# We'll still always have one byte chopped off the end | |
assert n_nonpad == n_nonmask == ((sum(patch_lengths) - 1) * n_seqs) | |
def test_long_batch_correctness_patches(): | |
# First patch length is forced to be 1 | |
# Total=48 | |
patch_lengths = [1, 255, 256, 256, 1024] | |
# Recall: This is in terms of bytes | |
max_length = 1024 | |
# Recall: This is in terms of patches | |
seq_len = 5 | |
n_seqs = 12 | |
batch_size = 4 | |
pad_id = 0 | |
iterator = create_patches_iter( | |
seq_len=seq_len, | |
n_seqs=n_seqs, | |
batch_size=batch_size, | |
pad_id=pad_id, | |
patch_lengths=patch_lengths, | |
max_length=max_length, | |
) | |
batches = [] | |
n_nonpad = 0 | |
n_nonmask = 0 | |
for batch in iterator: | |
assert batch.x.shape[0] == batch_size | |
assert batch.x.shape[1] == max_length | |
n_nonpad += (batch.x != pad_id).sum() | |
if batch.mask is None: | |
n_nonmask += batch.x.size | |
else: | |
n_nonmask += batch.mask.sum() | |
batches.append(batch) | |
assert len(batches) == 3 | |
# No chop here since the next byte is available | |
assert n_nonpad == n_nonmask == max_length * n_seqs | |
def test_merge_patch_seq_masks(): | |
batch_size = 4 | |
seq_len = 1024 | |
masks = [] | |
masks.append([True] * 1025) | |
masks.append([True] * 512) | |
masks.append([True] * 256) | |
masks.append([True] * 10) | |
expected_mask = np.zeros((batch_size, seq_len), dtype=bool) | |
expected_mask[0, :] = True | |
expected_mask[1, :511] = True | |
expected_mask[2, :255] = True | |
expected_mask[3, :9] = True | |
merged_mask = _merge_patch_seq_masks(batch_size, seq_len, masks) | |
assert (merged_mask == expected_mask).all() | |
with pytest.raises(AssertionError): | |
masks = [] | |
masks.append([True] * 1024) | |
masks.append([True] * 512) | |
masks.append([True] * 256) | |
masks.append([True] * 10) | |
_merge_patch_seq_masks(batch_size, seq_len, masks) | |