import os |
import argparse |
import csv |
import logging |
import pickle |
import numpy as np |
import torch |
import transformers |
import src.slurm |
import src.contriever |
import src.utils |
import src.data |
import src.normalize_text |
def _input_device(model):
    """Return the device that inputs must live on before calling ``model``."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        # Not a parameterised torch module: keep the historical behaviour of
        # sending inputs to the current CUDA device.
        return torch.device("cuda")


def _passage_text(args, passage):
    """Build the string to encode for one passage, honouring the text flags in ``args``."""
    if args.no_title or "title" not in passage:
        text = passage["text"]
    else:
        text = passage["title"] + " " + passage["text"]
    if args.lowercase:
        text = text.lower()
    if args.normalize_text:
        text = src.normalize_text.normalize(text)
    return text


def _iter_batches(args, passages):
    """Yield ``(ids, texts)`` list pairs holding up to ``args.per_gpu_batch_size`` passages."""
    ids, texts = [], []
    for passage in passages:
        ids.append(passage["id"])
        texts.append(_passage_text(args, passage))
        if len(texts) == args.per_gpu_batch_size:
            yield ids, texts
            ids, texts = [], []
    # Trailing partial batch; detected without needing len(passages).
    if texts:
        yield ids, texts


def _encode_batch(args, texts, model, tokenizer, device):
    """Tokenize ``texts``, run ``model`` on ``device`` and return its output on the CPU."""
    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors="pt",
        max_length=args.passage_maxlength,
        padding=True,
        truncation=True,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    return model(**encoded).cpu()


def embed_passages(args, passages, model, tokenizer):
    """Encode passages in batches and return their ids and embeddings.

    Args:
        args: Namespace providing ``no_title``, ``lowercase``, ``normalize_text``,
            ``per_gpu_batch_size`` and ``passage_maxlength``.
        passages: Iterable of mappings with an ``"id"`` and a ``"text"`` entry and
            an optional ``"title"`` entry. Any iterable is accepted; it does not
            need to support ``len()``.
        model: Callable taking the tokenizer output as keyword arguments and
            returning a tensor. Inputs are moved to the device of its parameters
            (falling back to CUDA when it exposes none).
        tokenizer: Object exposing ``batch_encode_plus``.

    Returns:
        A ``(ids, embeddings)`` tuple. ``ids`` lists the passage ids in input
        order; ``embeddings`` is the NumPy array obtained by concatenating the
        per-batch model outputs along dimension 0 (presumably one row per
        passage -- depends on what ``model`` returns). When ``passages`` is
        empty, ``embeddings`` is an empty ``(0, 0)`` float32 array because the
        embedding width is unknown without running the model.
    """
    report_every = 100000
    next_report = report_every
    allids, allembeddings = [], []
    device = _input_device(model)

    with torch.no_grad():
        for batch_ids, batch_text in _iter_batches(args, passages):
            allembeddings.append(_encode_batch(args, batch_text, model, tokenizer, device))
            allids.extend(batch_ids)
            # Report each time the running total crosses a multiple of
            # ``report_every``; keying on the loop index would almost never
            # coincide with a batch boundary.
            if len(allids) >= next_report:
                print(f"Encoded passages {len(allids)}")
                next_report = (len(allids) // report_every + 1) * report_every

    if not allembeddings:
        # torch.cat rejects an empty list; an empty shard is not an error.
        return allids, np.empty((0, 0), dtype=np.float32)
    return allids, torch.cat(allembeddings, dim=0).numpy()
def main(args):
    """Embed one shard of the passage collection and pickle the result.

    Loads the retriever named by ``args.model_name_or_path``, moves it to CUDA
    (in half precision unless ``args.no_fp16`` is set), selects the slice of
    passages belonging to ``args.shard_id`` out of ``args.num_shards``, embeds
    it and writes ``(ids, embeddings)`` to
    ``<output_dir>/<prefix>_<shard_id>``.
    """
    retriever, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path)
    print(f"Model loaded from {args.model_name_or_path}.", flush=True)
    retriever.eval()
    retriever = retriever.cuda()
    if not args.no_fp16:
        retriever = retriever.half()

    corpus = src.data.load_passages(args.passages)

    # Equal-sized shards; the last shard also takes the division remainder.
    per_shard = len(corpus) // args.num_shards
    start_idx = args.shard_id * per_shard
    is_last_shard = args.shard_id == args.num_shards - 1
    end_idx = len(corpus) if is_last_shard else start_idx + per_shard
    shard = corpus[start_idx:end_idx]
    print(f"Embedding generation for {len(shard)} passages from idx {start_idx} to {end_idx}.")

    allids, allembeddings = embed_passages(args, shard, retriever, tokenizer)

    os.makedirs(args.output_dir, exist_ok=True)
    save_file = os.path.join(args.output_dir, f"{args.prefix}_{args.shard_id:02d}")
    print(f"Saving {len(allids)} passage embeddings to {save_file}.")
    with open(save_file, mode="wb") as out:
        pickle.dump((allids, allembeddings), out)
    print(f"Total passages processed {len(allids)}. Written to {save_file}.")
if __name__ == "__main__":
    # Command-line entry point: parse options, set up distributed mode, then
    # embed the shard selected by --shard_id / --num_shards.
    parser = argparse.ArgumentParser()

    # Input / output locations.
    parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
    parser.add_argument("--output_dir", type=str, default="wikipedia_embeddings", help="dir path to save embeddings")
    parser.add_argument("--prefix", type=str, default="passages", help="prefix path to save embeddings")

    # Sharding of the passage collection across jobs.
    parser.add_argument("--shard_id", type=int, default=0, help="Id of the current shard")
    parser.add_argument("--num_shards", type=int, default=1, help="Total number of shards")

    # Encoder settings.
    parser.add_argument(
        "--per_gpu_batch_size", type=int, default=512, help="Batch size for the passage encoder forward pass"
    )
    parser.add_argument("--passage_maxlength", type=int, default=512, help="Maximum number of tokens in a passage")
    parser.add_argument(
        "--model_name_or_path", type=str, help="path to directory containing model weights and config file"
    )
    parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")

    # Text preprocessing applied to each passage before tokenization.
    parser.add_argument("--no_title", action="store_true", help="title not added to the passage body")
    parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
    # Help text previously duplicated the --lowercase description.
    parser.add_argument("--normalize_text", action="store_true", help="normalize text before encoding")

    args = parser.parse_args()

    src.slurm.init_distributed_mode(args)

    main(args)