import argparse
import json
import logging
from collections.abc import Iterator
from pathlib import Path

import numpy as np
from datasets import load_dataset
from more_itertools import batched
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer

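# Number of encoded batches to buffer in memory before flushing texts and embeddings to disk.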
_SAVE_EVERY = 32


logger = logging.getLogger(__name__)


def featurize(
    dataset: Iterator[dict[str, str]],
    model: SentenceTransformer,
    output_dir: str,
    max_means: int,
    batch_size: int,
    text_key: str,
) -> None:
    """Encode the dataset's texts and write them, together with their mean token embeddings, to output_dir."""
    output_dir_path = Path(output_dir)
    output_dir_path.mkdir(parents=True, exist_ok=True)

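    # Saved files are named feature_{batch_index}.json / .npy, so the largest index already
    # on disk marks the last completed flush; earlier batches are skipped when resuming.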
    largest_batch = max((int(x.stem.split("_")[1]) for x in output_dir_path.glob("*.json")), default=0)
    if largest_batch:
        logger.info(f"Resuming from batch {largest_batch}, skipping previous batches.")

    texts = []
    embeddings = []
    dim = model.get_sentence_embedding_dimension()
    if dim is None:
        msg = "Model has no sentence embedding dimension."
        raise ValueError(msg)

    tokenizer: PreTrainedTokenizer = model.tokenizer

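    # Stream the dataset in fixed-size batches, skipping any batch index that was already saved.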
    i = 0
    for i, batch in tqdm(enumerate(batched(dataset, n=batch_size))):
        if i * batch_size >= max_means:
            logger.info(f"Reached maximum number of means: {max_means}")
            break
        if largest_batch and i <= largest_batch:
            continue
        batch = [x[text_key] for x in batch]

        if not all(isinstance(x, str) for x in batch):
            msg = f"Detected non-string at batch: {i}"
            raise ValueError(msg)

        # Token-level embeddings: drop the special tokens at both ends and mean-pool the rest.
        batch_embeddings = model.encode(batch, output_value="token_embeddings")
        for text, embedding in zip(batch, batch_embeddings, strict=False):
            texts.append(_truncate_text(tokenizer, text))
            embeddings.append(embedding[1:-1].mean(axis=0).cpu().numpy())
        if i and i % _SAVE_EVERY == 0:
            with open(output_dir_path / f"feature_{i}.json", "w") as f:
                json.dump(texts, f, indent=4)
            np.save(output_dir_path / f"feature_{i}.npy", embeddings)
            texts = []
            embeddings = []
    # Flush anything left over after the loop ends.
    if texts:
        with open(output_dir_path / f"feature_{i}.json", "w") as f:
            json.dump(texts, f, indent=4)
        np.save(output_dir_path / f"feature_{i}.npy", embeddings)


def _truncate_text(tokenizer: PreTrainedTokenizer, text: str) -> str:
    """Truncate text to fit the tokenizer's maximum length."""
    tokens = tokenizer.encode(
        text,
        truncation=True,
        max_length=tokenizer.model_max_length,
    )
    return tokenizer.decode(tokens, skip_special_tokens=True)


def main() -> None:
    """Main function to featurize texts using a sentence transformer."""
    parser = argparse.ArgumentParser(description="Featurize texts using a sentence transformer.")
    parser.add_argument(
        "--model-name",
        type=str,
        default="BAAI/bge-base-en-v1.5",
        help="The sentence transformer model to featurize with (e.g., 'BAAI/bge-base-en-v1.5').",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Directory to save the featurized texts.",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        default="allenai/c4",
        help="The dataset path or name (e.g., 'allenai/c4').",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="en",
        help="The dataset configuration name (e.g., 'en' for C4).",
    )
    parser.add_argument(
        "--dataset-split",
        type=str,
        default="train",
        help="The dataset split (e.g., 'train', 'validation').",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Disable streaming mode when loading the dataset.",
    )
    parser.add_argument(
        "--max-means",
        type=int,
        default=1_000_000,
        help="The maximum number of mean embeddings to generate.",
    )
    parser.add_argument(
        "--key",
        type=str,
        default="text",
        help="The key of the text field in the dataset to featurize (default: 'text').",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size to use for encoding the texts.",
    )

    args = parser.parse_args()

    if args.output_dir is None:
        model_name = args.model_name.replace("/", "_")
        dataset_path = args.dataset_path.replace("/", "_")
        output_dir = f"{model_name}_{dataset_path}_featurized"
    else:
        output_dir = args.output_dir

    model = SentenceTransformer(args.model_name)
    dataset = load_dataset(
        args.dataset_path,
        name=args.dataset_name,
        split=args.dataset_split,
        streaming=args.streaming,
    )

    featurize(iter(dataset), model, output_dir, args.max_means, args.batch_size, args.key)


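# Example invocation (sketch; the script filename and values are illustrative):
#   python featurize.py --model-name BAAI/bge-base-en-v1.5 --dataset-path allenai/c4 \
#       --dataset-name en --dataset-split train --max-means 1000000 --batch-size 32
# Each flush writes feature_{batch}.json (the texts) and feature_{batch}.npy (their mean
# token embeddings, one row per text) into the output directory.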
if __name__ == "__main__":
    main()