# Copyright (c) Meta Platforms, Inc. and affiliates.
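"""
Precompute entropies for a JSONL dataset.

For every document, the text is tokenized, scored with a pretrained entropy
model, and the per-token entropies are written together with the sample id
and text to an Arrow IPC file.
"""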
import time

import fsspec
import jsonlines
import numpy as np
import pyarrow as pa
import torch
import typer
from rich.progress import Progress, TextColumn

from bytelatent.data.file_util import get_fs
from bytelatent.data.patcher import calculate_entropies
from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


def get_id_key(doc: dict) -> str:
    """
    We need a reliable way to ensure that samples from jsonl
    and arrow are the same, but there is no unique id field,
    so derive the best available key.
    """
    if "sample_id" in doc:
        return "sample_id"
    elif "title" in doc:
        return "title"
    elif "qid" in doc:
        return "qid"
    elif "paper_id" in doc:
        return "paper_id"
    elif "path" in doc:
        return "path"
    elif "url" in doc:
        return "url"
    elif "id" in doc:
        return "id"
    else:
        raise ValueError(f"Could not find an id key from: {doc.keys()}")


def get_id_from_doc(doc: dict) -> str:
    """
    We need a reliable way to ensure that samples from jsonl
    and arrow are the same, but there is no unique id field,
    so derive the best available key.
    """
    return str(doc[get_id_key(doc)])


def get_text(doc: dict):
    if "text" in doc:
        text = doc["text"]
    elif "content" in doc:
        text = doc["content"]
    else:
        raise ValueError(f"Could not find a text key from: {doc.keys()}")
    return text


def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
    with fs.open(path) as f:
        reader = jsonlines.Reader(f)
        yield from reader
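

# CLI entry point (run via typer, see the bottom of the file): reads documents
# from input_file, computes per-token entropies with the entropy model, and
# writes sample_id, text, and entropies columns to output_file.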
def main(
    input_file: str,
    output_file: str,
    patching_device: str = "cuda",
    log_step: int = 10_000,
    entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
    entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
    bpe_tokenizer_path: str = "public_data/tokenizer.model",
    dry_run: bool = False,
    s3_profile: str | None = None,
):
    print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
    print("Loading entropy model", entropy_model_checkpoint_dir)
    input_fs = get_fs(input_file, s3_profile=s3_profile)
    input_doc_iterator = jsonl_file_iterator(input_fs, input_file)
    if dry_run:
        return
    entropy_model = load_entropy_model(
        entropy_model_checkpoint_dir,
        entropy_model_state_dict_path,
        device=patching_device,
    )
    print("Creating patcher")
    patching_batch_size = 32
    print("Creating tokenizer")
    tokenizer_args = TokenizerArgs(
        name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
    )
    tokenizer = tokenizer_args.build()
    step = 0
    print("starting")
    start_time = time.time()
    patch_time = 0
    entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False)
    sample_id_field = pa.field("sample_id", pa.string(), nullable=False)
    text_field = pa.field("text", pa.string(), nullable=False)
    schema = pa.schema([sample_id_field, text_field, entropy_field])
    arrow_batch_size = 1_000
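    # Rows are buffered in memory and flushed to the Arrow writer every
    # arrow_batch_size documents; if anything fails, the partially written
    # output file is removed before the exception is re-raised.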
    output_fs = get_fs(output_file, s3_profile=s3_profile)
    try:
        with output_fs.open(output_file, "wb") as sink:
            with pa.ipc.new_file(sink, schema) as writer:
                id_buffer = []
                entropies_buffer = []
                text_buffer = []
                with Progress(
                    *Progress.get_default_columns(),
                    TextColumn("Completed: {task.completed}"),
                ) as progress:
                    task = progress.add_task(
                        "[green]Calculating entropies...", total=None
                    )
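                    # For each document: tokenize, score with the entropy
                    # model, and buffer the per-token entropies (as float16)
                    # alongside the sample id and text.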
                    for doc in input_doc_iterator:
                        sample_id = get_id_from_doc(doc)
                        text = get_text(doc)
                        tokens = torch.tensor(tokenizer.encode(text))
                        patch_start = time.time()
                        scores, _ = calculate_entropies(
                            tokens,
                            entropy_model,
                            patching_batch_size,
                            patching_device,
                        )
                        entropies_buffer.append(
                            np.array(scores.tolist(), dtype=np.float16)
                        )
                        id_buffer.append(sample_id)
                        text_buffer.append(text)
                        if len(entropies_buffer) == arrow_batch_size:
                            batch = pa.record_batch(
                                {
                                    "entropies": entropies_buffer,
                                    "sample_id": id_buffer,
                                    "text": text_buffer,
                                },
                                schema,
                            )
                            writer.write(batch)
                            entropies_buffer = []
                            id_buffer = []
                            text_buffer = []
                        patch_time += time.time() - patch_start
                        step += 1
                        if step % log_step == 0:
                            print("Completed steps:", step)
                        progress.update(task, advance=1)
                    if len(entropies_buffer) > 0:
                        # Flush any rows still buffered after the last full batch
                        batch = pa.record_batch(
                            {
                                "entropies": entropies_buffer,
                                "sample_id": id_buffer,
                                "text": text_buffer,
                            },
                            schema,
                        )
                        writer.write(batch)
                        entropies_buffer = []
                        id_buffer = []
                        text_buffer = []
        # Marker file signalling that the output was written completely
        output_fs.touch(f"{output_file}.complete")
    except:
        if output_fs.exists(output_file):
            output_fs.rm(output_file)
        raise
    elapsed = time.time() - start_time
    print("steps", step)
    print("done in:", elapsed)


if __name__ == "__main__":
    typer.run(main)