# Copyright (c) Meta Platforms, Inc. and affiliates.
import numpy as np
import pyarrow as pa

# pyarrow needs the initialization from this import
import pyarrow.dataset  # pyright: ignore

from bytelatent.constants import BLT_DATA
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    ArrowFileIteratorState,
)

ENTROPY_MODEL = "transformer_100m"
ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
ARROW_TEST_DATA_2 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_01.arrow")


def test_basic_arrow_file():
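    # Read the first n_head rows directly with pyarrow; they serve as the
    # ground truth that the iterator's output is checked against.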
    dataset = pa.dataset.dataset(ARROW_TEST_DATA_1, format="arrow")
    n_head = 1000
    head_df = dataset.head(n_head).to_pandas()
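
    # Build an iterator from an explicit state snapshot; row_num=0 starts
    # reading at the beginning of the file.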
    initial_state = ArrowFileIteratorState(
        file_path=None,
        num_workers=1,
        worker_id=0,
        preprocess_dir=None,
        entropy_model_name=ENTROPY_MODEL,
        dataset_files=[ARROW_TEST_DATA_1],
        row_num=0,
        arrow_batch_size=100,
        s3_profile=None,
        file_format="arrow",
    )
    arrow_file = initial_state.build()
    start_state = arrow_file.get_state()
    assert start_state.row_num == initial_state.row_num

    sample_id = None
    for example in arrow_file.create_iter():
        sample_id = example.sample_id
        assert head_df.iloc[0]["sample_id"] == sample_id
        break

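    # Reading a single example should advance the iterator state by one row.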
    assert arrow_file.get_state().row_num == 1
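
    # Rebuilding from the same initial state should replay the same first example.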
    arrow_file = initial_state.build()
    for example in arrow_file.create_iter():
        assert example.sample_id == sample_id
        assert head_df.iloc[0]["sample_id"] == sample_id
        break

    # Test resume far enough in to be past the batch size of 100
    resumed_state = ArrowFileIteratorState(
        file_path=None,
        num_workers=1,
        worker_id=0,
        preprocess_dir=None,
        entropy_model_name=ENTROPY_MODEL,
        dataset_files=[ARROW_TEST_DATA_1],
        row_num=251,
        arrow_batch_size=100,
        s3_profile=None,
        file_format="arrow",
    )
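    # The first example after resuming should be row 251, and reading it
    # should advance the state to row 252.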
    arrow_file = resumed_state.build()
    for example in arrow_file.create_iter():
        assert example.sample_id == head_df.iloc[251]["sample_id"]
        assert arrow_file.get_state().row_num == 252
        break

    world_rank = 1
    world_size = 4
    # Test World Size and Rank
    rank_state = ArrowFileIteratorState(
        file_path=None,
        num_workers=world_size,
        worker_id=world_rank,
        preprocess_dir=None,
        entropy_model_name=ENTROPY_MODEL,
        dataset_files=[ARROW_TEST_DATA_1],
        row_num=0,
        arrow_batch_size=100,
        s3_profile=None,
        file_format="arrow",
    )
    arrow_file = rank_state.build()
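    # With num_workers=world_size and worker_id=world_rank, this worker should
    # see exactly the rows where row_index % world_size == world_rank.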
    expected_ids = []
    for i in range(n_head):
        if i % world_size == world_rank:
            expected_ids.append(head_df.iloc[i]["sample_id"])
    print(len(expected_ids))
    i = 0
    for example in arrow_file.create_iter():
        assert example.sample_id == expected_ids[i]
        i += 1
        if i >= len(expected_ids):
            break


def test_read_jsonl_from_arrow():
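    # ArrowFileIterator can also read a jsonl file when file_format="json";
    # here preprocess_dir and entropy_model_name are left as None.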
    arrow_iterator = ArrowFileIterator(
        file_path="fixtures/test_docs.jsonl",
        num_workers=1,
        worker_id=0,
        preprocess_dir=None,
        entropy_model_name=None,
        file_format="json",
        arrow_batch_size=100,
    )
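    # The fixture is expected to contain rows with sample_id "0", "1", ...
    # and matching text "test_0", "test_1", ...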
    iterator = arrow_iterator.create_iter()
    for i, example in enumerate(iterator):
        assert example.sample_id == str(i)
        assert example.text == f"test_{i}"