Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
import os | |
import re | |
from logging import getLogger | |
from typing import Any, Generator | |
import fsspec | |
import pyarrow as pa | |
# pyarrow needs the initialization from this import | |
import pyarrow.dataset # pyright: ignore | |
import s3fs | |
from pydantic import BaseModel, ConfigDict | |
from bytelatent import ByteLatentError | |
from bytelatent.data.data_types import BltExample | |
from bytelatent.data.file_util import get_fs | |
from bytelatent.data.iterators.abstract_iterator import ( | |
PydanticIteratorState, | |
StatefulIterator, | |
) | |
from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text | |
logger = getLogger(__name__) | |
class ArrowFileIteratorState(PydanticIteratorState): | |
model_config = ConfigDict(extra="forbid") | |
file_path: str | None | |
row_num: int | |
num_workers: int | |
worker_id: int | |
preprocess_dir: str | None | |
dataset_files: list[str] | None | |
entropy_model_name: str | None | |
arrow_batch_size: int = 100 | |
s3_profile: str | None | |
filesystem_type: str | None = None | |
file_format: str | |
def build(self) -> "ArrowFileIterator": | |
arrow_file = ArrowFileIterator( | |
file_path=self.file_path, | |
worker_id=self.worker_id, | |
num_workers=self.num_workers, | |
preprocess_dir=self.preprocess_dir, | |
entropy_model_name=self.entropy_model_name, | |
arrow_batch_size=self.arrow_batch_size, | |
dataset_files=self.dataset_files, | |
s3_profile=self.s3_profile, | |
filesystem_type=self.filesystem_type, | |
file_format=self.file_format, | |
) | |
if self.row_num != 0: | |
arrow_file._set_row_num(self.row_num) | |
return arrow_file | |
def shard_sort_key(file: str): | |
assert isinstance(file, str) | |
match = re.search(r".+\.shard_([0-9]+)\.arrow", file) | |
shard_number = int(match.group(1)) | |
return shard_number | |
def maybe_truncate_string(text: str, max_length: int): | |
if len(text) <= max_length: | |
return text | |
else: | |
return text[:max_length] + "..." | |
class ArrowFileIterator(StatefulIterator): | |
def __init__( | |
self, | |
*, | |
file_path: str | None, | |
worker_id: int, | |
num_workers: int, | |
preprocess_dir: str | None, | |
entropy_model_name: str | None, | |
arrow_batch_size: int, | |
dataset_files: list[str] | None = None, | |
s3_profile: str | None = None, | |
filesystem_type: str | None = None, | |
file_format: str = "arrow", | |
): | |
assert 0 <= worker_id < num_workers, (worker_id, num_workers) | |
if file_path is None and dataset_files is None: | |
raise ByteLatentError("file_path and dataset_files cannot both be None") | |
self.row_num = 0 | |
self.iter_id = 0 | |
self.batch_iterator = None | |
self.batch_to_consume = None | |
self.dataset = None | |
self.file_path = file_path | |
self.worker_id = worker_id | |
self.num_workers = num_workers | |
self.preprocess_dir = preprocess_dir | |
self.entropy_model_name = entropy_model_name | |
self.arrow_batch_size = arrow_batch_size | |
self.s3_profile = s3_profile | |
self.filesystem_type = filesystem_type | |
self.file_format = file_format | |
self.fs = None | |
if self.filesystem_type is not None: | |
if self.filesystem_type == "file": | |
self.fs = fsspec.filesystem("file") | |
elif self.filesystem_type == "s3": | |
self.fs = fsspec.filesystem("s3", profile=s3_profile) | |
else: | |
raise ValueError("Unknown filesystem") | |
logger.info("Arrow iterator using fs=%s", self.fs) | |
if dataset_files is None: | |
assert ( | |
file_path is not None | |
), "Must specify file_Path if dataset_files is None" | |
if file_format == "json": | |
if self.fs is None: | |
self.fs = get_fs(file_path, s3_profile=s3_profile) | |
if isinstance(self.fs, s3fs.S3FileSystem): | |
self.filesystem_type = "s3" | |
else: | |
self.filesystem_type = "file" | |
self.dataset_files = [file_path] | |
else: | |
# Prepare arrow shards | |
jsonl_file = file_path | |
parts = re.match( | |
r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) | |
) | |
assert parts is not None | |
dataset = parts.group(1) | |
data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) | |
data_dir_with_glob = os.path.join( | |
data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" | |
) | |
if self.fs is None: | |
self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) | |
if isinstance(self.fs, s3fs.S3FileSystem): | |
self.filesystem_type = "s3" | |
else: | |
self.filesystem_type = "file" | |
shard_files = self.fs.glob(data_dir_with_glob) | |
for s in shard_files: | |
complete_file = os.path.join( | |
data_dir, f"{os.path.basename(s)}.complete" | |
) | |
if not self.fs.exists(complete_file): | |
raise ValueError(f"Missing .complete for input file: {s}") | |
shard_files = sorted(shard_files, key=shard_sort_key) | |
if len(shard_files) == 0: | |
raise ByteLatentError( | |
f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" | |
) | |
self.dataset_files = [f for f in shard_files] | |
else: | |
self.preprocess_dir = None | |
self.dataset_files = dataset_files | |
if dataset_files[0].startswith("s3://"): | |
for f in dataset_files: | |
assert f.startswith("s3://") | |
if self.fs is None: | |
self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) | |
if isinstance(self.fs, s3fs.S3FileSystem): | |
self.filesystem_type = "s3" | |
else: | |
self.filesystem_type = "file" | |
def get_state(self) -> ArrowFileIteratorState: | |
return ArrowFileIteratorState( | |
file_path=self.file_path, | |
row_num=self.row_num, | |
worker_id=self.worker_id, | |
num_workers=self.num_workers, | |
preprocess_dir=self.preprocess_dir, | |
entropy_model_name=self.entropy_model_name, | |
arrow_batch_size=self.arrow_batch_size, | |
dataset_files=self.dataset_files, | |
s3_profile=self.s3_profile, | |
filesystem_type=self.filesystem_type, | |
file_format=self.file_format, | |
) | |
def create_iter( | |
self, | |
) -> Generator[BltExample, Any, None]: | |
if self.dataset is None: | |
if isinstance(self.fs, s3fs.core.S3FileSystem): | |
filesystem = self.fs | |
else: | |
filesystem = None | |
self.dataset = pa.dataset.dataset( | |
self.dataset_files, format=self.file_format, filesystem=filesystem | |
) | |
self.iter_id += 1 | |
if self.batch_to_consume is not None: | |
batch_columns: dict[str, list] = self.batch_to_consume | |
self.batch_to_consume = None | |
if self.file_format == "arrow": | |
sample_ids = batch_columns["sample_id"] | |
texts = batch_columns["text"] | |
entropies = batch_columns["entropies"] | |
elif self.file_format == "json": | |
# This data hasn't been preprocessed to a uniform format, | |
# so we have to do it now and omit entropies | |
sample_ids = batch_columns[get_id_key(batch_columns)] | |
texts = get_text(batch_columns) | |
entropies = None | |
else: | |
raise ValueError(f"Unknown file format: {self.file_format}") | |
for i in range(len(sample_ids)): | |
out = BltExample( | |
sample_id=str(sample_ids[i]), | |
entropies=entropies[i] if entropies is not None else None, | |
text=texts[i], | |
tokens=None, | |
mask=None, | |
patch_lengths=None, | |
) | |
self.row_num += 1 | |
if (self.row_num - 1) % self.num_workers == self.worker_id: | |
yield out | |
self.batch_iterator = self.dataset.to_batches( | |
batch_size=self.arrow_batch_size, | |
# We have large files in GBs, no need to readahead | |
fragment_readahead=1, | |
# Don't readahead in case batches are huge (e.g., books) | |
batch_readahead=1, | |
) | |
for batch in self.batch_iterator: | |
batch_columns = batch.to_pydict() | |
if self.file_format == "arrow": | |
sample_ids = batch_columns["sample_id"] | |
texts = batch_columns["text"] | |
entropies = batch_columns["entropies"] | |
elif self.file_format == "json": | |
# This data hasn't been preprocessed to a uniform format, | |
# so we have to do it now and omit entropies | |
sample_ids = batch_columns[get_id_key(batch_columns)] | |
texts = get_text(batch_columns) | |
entropies = None | |
else: | |
raise ValueError(f"Unknown file format: {self.file_format}") | |
for i in range(len(sample_ids)): | |
out = BltExample( | |
sample_id=str(sample_ids[i]), | |
entropies=entropies[i] if entropies is not None else None, | |
text=texts[i], | |
tokens=None, | |
mask=None, | |
patch_lengths=None, | |
) | |
self.row_num += 1 | |
if (self.row_num - 1) % self.num_workers == self.worker_id: | |
yield out | |
def _set_row_num(self, target_row_num: int): | |
data_str = maybe_truncate_string(str(self.dataset_files), 200) | |
logger.info(f"Setting arrow position to {target_row_num} for {data_str}") | |
if target_row_num is None or target_row_num == 0: | |
self.row_num = 0 | |
self.dataset = None | |
self.batch_iterator = None | |
self.batch_to_consume = None | |
else: | |
if isinstance(self.fs, s3fs.core.S3FileSystem): | |
filesystem = self.fs | |
else: | |
filesystem = None | |
self.dataset = pa.dataset.dataset( | |
self.dataset_files, format="arrow", filesystem=filesystem | |
) | |
self.batch_iterator = self.dataset.to_batches( | |
batch_size=self.arrow_batch_size | |
) | |
curr_remaining = target_row_num | |
for batch in self.batch_iterator: | |
if len(batch) > curr_remaining: | |
batch_columns: dict[str, list] = batch.to_pydict() | |
if self.file_format == "arrow": | |
leftover_sample_ids = batch_columns["sample_id"][ | |
curr_remaining: | |
] | |
leftover_entropies = batch_columns["entropies"][curr_remaining:] | |
leftover_texts = batch_columns["text"][curr_remaining:] | |
elif self.file_format == "json": | |
leftover_sample_ids = batch_columns[get_id_key(batch_columns)][ | |
curr_remaining: | |
] | |
leftover_entropies = None | |
leftover_texts = get_text(batch_columns)[curr_remaining:] | |
else: | |
raise ValueError(f"Unknown file format: {self.file_format}") | |
batch_columns["sample_id"] = leftover_sample_ids | |
batch_columns["entropies"] = leftover_entropies | |
batch_columns["text"] = leftover_texts | |
self.batch_to_consume = batch_columns | |
break | |
elif len(batch) == curr_remaining: | |
# We are exactly at the end of the batch, | |
# so the next batch is the right spot | |
break | |
else: | |
curr_remaining -= len(batch) | |
self.row_num = target_row_num | |
data_str = maybe_truncate_string(str(self.dataset_files), 200) | |
logger.info( | |
f"Finished setting arrow position to {target_row_num} for {data_str}" | |
) | |