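"""Featurize a text dataset into mean token embeddings with a SentenceTransformer.

Each batch of texts is encoded to token embeddings, mean-pooled (dropping the first and
last special tokens), and periodically written together with the truncated texts to
feature_<batch>.json / feature_<batch>.npy files in the output directory.
"""
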
import argparse
import json
import logging
from collections.abc import Iterator
from pathlib import Path
import numpy as np
from datasets import load_dataset
from more_itertools import batched
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer
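
# Number of batches to accumulate before each checkpoint is written to disk.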
_SAVE_EVERY = 32
logger = logging.getLogger(__name__)


def featurize(
dataset: Iterator[dict[str, str]],
model: SentenceTransformer,
output_dir: str,
max_means: int,
batch_size: int,
text_key: str,
) -> None:
"""Make a directory and dump all kinds of data in it."""
output_dir_path = Path(output_dir)
output_dir_path.mkdir(parents=True, exist_ok=True)
    # Resume support: find the highest batch index already written to the output
    # directory (files are named "feature_<batch>.json") so earlier batches can be skipped.
    largest_batch = max([int(x.stem.split("_")[1]) for x in output_dir_path.glob("*.json")], default=0)
if largest_batch:
logger.info(f"Resuming from batch {largest_batch}, skipping previous batches.")
texts = []
embeddings = []
dim = model.get_sentence_embedding_dimension()
if dim is None:
msg = "Model has no sentence embedding dimension."
raise ValueError(msg)
tokenizer: PreTrainedTokenizer = model.tokenizer
# Binding i in case the dataset is empty.
i = 0
for i, batch in tqdm(enumerate(batched(dataset, n=batch_size))):
if i * batch_size >= max_means:
logger.info(f"Reached maximum number of means: {max_means}")
break
if largest_batch and i <= largest_batch:
continue
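        # Pull the raw text strings out of the batch records and validate their types.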
batch = [x[text_key] for x in batch]
if not all(isinstance(x, str) for x in batch):
msg = f"Detected non-string at batch: {i}"
raise ValueError(msg)
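        # output_value="token_embeddings" returns per-token vectors rather than a pooled
        # sentence embedding; the mean over tokens is computed below.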
batch_embeddings = model.encode(batch, output_value="token_embeddings") # type: ignore # Annoying
for text, embedding in zip(batch, batch_embeddings, strict=False):
texts.append(_truncate_text(tokenizer, text))
embeddings.append(embedding[1:-1].mean(axis=0).cpu().numpy())
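        # Periodically flush the accumulated texts and mean embeddings so long runs can be
        # resumed from the last completed batch.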
        if i and i % _SAVE_EVERY == 0:
            with open(output_dir_path / f"feature_{i}.json", "w") as f:
                json.dump(texts, f, indent=4)
            np.save(output_dir_path / f"feature_{i}.npy", embeddings)
            texts = []
            embeddings = []
    if texts:
        # Write any leftover texts and embeddings from the final, partial chunk.
        with open(output_dir_path / f"feature_{i}.json", "w") as f:
            json.dump(texts, f, indent=4)
        np.save(output_dir_path / f"feature_{i}.npy", embeddings)


def _truncate_text(tokenizer: PreTrainedTokenizer, text: str) -> str:
"""Truncate text to fit the tokenizer's maximum length."""
tokens = tokenizer.encode(
text,
truncation=True,
max_length=tokenizer.model_max_length,
)
return tokenizer.decode(tokens, skip_special_tokens=True)


def main() -> None:
"""Main function to featurize texts using a sentence transformer."""
parser = argparse.ArgumentParser(description="Featurize texts using a sentence transformer.")
parser.add_argument(
"--model-name",
type=str,
default="baai/bge-base-en-v1.5",
help="The model name for distillation (e.g., 'baai/bge-base-en-v1.5').",
)
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Directory to save the featurized texts.",
)
parser.add_argument(
"--dataset-path",
type=str,
default="allenai/c4",
help="The dataset path or name (e.g. 'allenai/c4').",
)
parser.add_argument(
"--dataset-name",
type=str,
default="en",
help="The dataset configuration name (e.g., 'en' for C4).",
)
parser.add_argument(
"--dataset-split",
type=str,
default="train",
help="The dataset split (e.g., 'train', 'validation').",
)
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Disable streaming mode when loading the dataset (streaming is enabled by default).",
    )
parser.add_argument(
"--max-means",
type=int,
default=1000000,
help="The maximum number of mean embeddings to generate.",
)
parser.add_argument(
"--key",
type=str,
default="text",
help="The key of the text field in the dataset to featurize (default: 'text').",
)
parser.add_argument(
"--batch-size",
type=int,
default=32,
help="Batch size to use for encoding the texts.",
)
args = parser.parse_args()
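    # Derive a default output directory from the model and dataset names when none is given.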
if args.output_dir is None:
model_name = args.model_name.replace("/", "_")
dataset_path = args.dataset_path.replace("/", "_")
output_dir = f"{model_name}_{dataset_path}_featurized"
else:
output_dir = args.output_dir
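    # Load the encoder and the dataset (streamed by default so the corpus is not fully downloaded).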
model = SentenceTransformer(args.model_name)
dataset = load_dataset(
args.dataset_path,
name=args.dataset_name,
split=args.dataset_split,
        streaming=args.streaming,
)
featurize(iter(dataset), model, output_dir, args.max_means, args.batch_size, args.key)


if __name__ == "__main__":
main()