| import os |
| import faiss |
| import json |
| import warnings |
| import numpy as np |
| from typing import cast, List, Dict |
| import shutil |
| import subprocess |
| import argparse |
| import torch |
| from tqdm import tqdm |
| |
| import datasets |
| from transformers import AutoTokenizer, AutoModel, AutoConfig |
|
|
|
|
| def load_model( |
| model_path: str, |
| use_fp16: bool = False |
| ): |
| model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) |
| model = AutoModel.from_pretrained(model_path, trust_remote_code=True) |
| model.eval() |
| model.cuda() |
| if use_fp16: |
| model = model.half() |
| tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True) |
|
|
| return model, tokenizer |
|
|
|
|
| def pooling( |
| pooler_output, |
| last_hidden_state, |
| attention_mask = None, |
| pooling_method = "mean" |
| ): |
| if pooling_method == "mean": |
| last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0) |
| return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] |
| elif pooling_method == "cls": |
| return last_hidden_state[:, 0] |
| elif pooling_method == "pooler": |
| return pooler_output |
| else: |
| raise NotImplementedError("Pooling method not implemented!") |
|
|
|
|
| def load_corpus(corpus_path: str): |
| corpus = datasets.load_dataset( |
| 'json', |
| data_files=corpus_path, |
| split="train", |
| num_proc=4) |
| return corpus |
|
|
|
|
| class Index_Builder: |
| r"""A tool class used to build an index used in retrieval. |
| |
| """ |
| def __init__( |
| self, |
| retrieval_method, |
| model_path, |
| corpus_path, |
| save_dir, |
| max_length, |
| batch_size, |
| use_fp16, |
| pooling_method, |
| faiss_type=None, |
| embedding_path=None, |
| save_embedding=False, |
| faiss_gpu=False |
| ): |
| |
| self.retrieval_method = retrieval_method.lower() |
| self.model_path = model_path |
| self.corpus_path = corpus_path |
| self.save_dir = save_dir |
| self.max_length = max_length |
| self.batch_size = batch_size |
| self.use_fp16 = use_fp16 |
| self.pooling_method = pooling_method |
| self.faiss_type = faiss_type if faiss_type is not None else 'Flat' |
| self.embedding_path = embedding_path |
| self.save_embedding = save_embedding |
| self.faiss_gpu = faiss_gpu |
|
|
| self.gpu_num = torch.cuda.device_count() |
| |
| print(self.save_dir) |
| if not os.path.exists(self.save_dir): |
| os.makedirs(self.save_dir) |
| else: |
| if not self._check_dir(self.save_dir): |
| warnings.warn("Some files already exists in save dir and may be overwritten.", UserWarning) |
|
|
| self.index_save_path = os.path.join(self.save_dir, f"{self.retrieval_method}_{self.faiss_type}.index") |
|
|
| self.embedding_save_path = os.path.join(self.save_dir, f"emb_{self.retrieval_method}.memmap") |
|
|
| self.corpus = load_corpus(self.corpus_path) |
| |
| print("Finish loading...") |
| @staticmethod |
| def _check_dir(dir_path): |
| r"""Check if the dir path exists and if there is content. |
| |
| """ |
| |
| if os.path.isdir(dir_path): |
| if len(os.listdir(dir_path)) > 0: |
| return False |
| else: |
| os.makedirs(dir_path, exist_ok=True) |
| return True |
|
|
| def build_index(self): |
| r"""Constructing different indexes based on selective retrieval method. |
| |
| """ |
| if self.retrieval_method == "bm25": |
| self.build_bm25_index() |
| else: |
| self.build_dense_index() |
|
|
| def build_bm25_index(self): |
| """Building BM25 index based on Pyserini library. |
| |
| Reference: https://github.com/castorini/pyserini/blob/master/docs/usage-index.md#building-a-bm25-index-direct-java-implementation |
| """ |
|
|
| |
| self.save_dir = os.path.join(self.save_dir, "bm25") |
| os.makedirs(self.save_dir, exist_ok=True) |
| temp_dir = self.save_dir + "/temp" |
| temp_file_path = temp_dir + "/temp.jsonl" |
| os.makedirs(temp_dir) |
|
|
| |
| |
| |
| |
| |
| |
| shutil.copyfile(self.corpus_path, temp_file_path) |
| |
| print("Start building bm25 index...") |
| pyserini_args = ["--collection", "JsonCollection", |
| "--input", temp_dir, |
| "--index", self.save_dir, |
| "--generator", "DefaultLuceneDocumentGenerator", |
| "--threads", "1"] |
| |
| subprocess.run(["python", "-m", "pyserini.index.lucene"] + pyserini_args) |
|
|
| shutil.rmtree(temp_dir) |
| |
| print("Finish!") |
|
|
| def _load_embedding(self, embedding_path, corpus_size, hidden_size): |
| all_embeddings = np.memmap( |
| embedding_path, |
| mode="r", |
| dtype=np.float32 |
| ).reshape(corpus_size, hidden_size) |
| return all_embeddings |
|
|
| def _save_embedding(self, all_embeddings): |
| memmap = np.memmap( |
| self.embedding_save_path, |
| shape=all_embeddings.shape, |
| mode="w+", |
| dtype=all_embeddings.dtype |
| ) |
| length = all_embeddings.shape[0] |
| |
| save_batch_size = 10000 |
| if length > save_batch_size: |
| for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"): |
| j = min(i + save_batch_size, length) |
| memmap[i: j] = all_embeddings[i: j] |
| else: |
| memmap[:] = all_embeddings |
|
|
| def encode_all(self): |
| if self.gpu_num > 1: |
| print("Use multi gpu!") |
| self.encoder = torch.nn.DataParallel(self.encoder) |
| self.batch_size = self.batch_size * self.gpu_num |
|
|
| all_embeddings = [] |
|
|
| for start_idx in tqdm(range(0, len(self.corpus), self.batch_size), desc='Inference Embeddings:'): |
|
|
| |
| |
| |
| batch_data = self.corpus[start_idx:start_idx+self.batch_size]['contents'] |
|
|
| if self.retrieval_method == "e5": |
| batch_data = [f"passage: {doc}" for doc in batch_data] |
|
|
| inputs = self.tokenizer( |
| batch_data, |
| padding=True, |
| truncation=True, |
| return_tensors='pt', |
| max_length=self.max_length, |
| ).to('cuda') |
|
|
| inputs = {k: v.cuda() for k, v in inputs.items()} |
|
|
| |
| if "T5" in type(self.encoder).__name__: |
| |
| decoder_input_ids = torch.zeros( |
| (inputs['input_ids'].shape[0], 1), dtype=torch.long |
| ).to(inputs['input_ids'].device) |
| output = self.encoder( |
| **inputs, decoder_input_ids=decoder_input_ids, return_dict=True |
| ) |
| embeddings = output.last_hidden_state[:, 0, :] |
|
|
| else: |
| output = self.encoder(**inputs, return_dict=True) |
| embeddings = pooling(output.pooler_output, |
| output.last_hidden_state, |
| inputs['attention_mask'], |
| self.pooling_method) |
| if "dpr" not in self.retrieval_method: |
| embeddings = torch.nn.functional.normalize(embeddings, dim=-1) |
|
|
| embeddings = cast(torch.Tensor, embeddings) |
| embeddings = embeddings.detach().cpu().numpy() |
| all_embeddings.append(embeddings) |
|
|
| all_embeddings = np.concatenate(all_embeddings, axis=0) |
| all_embeddings = all_embeddings.astype(np.float32) |
|
|
| return all_embeddings |
|
|
| @torch.no_grad() |
| def build_dense_index(self): |
| """Obtain the representation of documents based on the embedding model(BERT-based) and |
| construct a faiss index. |
| """ |
| |
| if os.path.exists(self.index_save_path): |
| print("The index file already exists and will be overwritten.") |
| |
| self.encoder, self.tokenizer = load_model(model_path = self.model_path, |
| use_fp16 = self.use_fp16) |
| if self.embedding_path is not None: |
| hidden_size = self.encoder.config.hidden_size |
| corpus_size = len(self.corpus) |
| all_embeddings = self._load_embedding(self.embedding_path, corpus_size, hidden_size) |
| else: |
| all_embeddings = self.encode_all() |
| if self.save_embedding: |
| self._save_embedding(all_embeddings) |
| del self.corpus |
|
|
| |
| print("Creating index") |
| dim = all_embeddings.shape[-1] |
| faiss_index = faiss.index_factory(dim, self.faiss_type, faiss.METRIC_INNER_PRODUCT) |
| |
| if self.faiss_gpu: |
| co = faiss.GpuMultipleClonerOptions() |
| co.useFloat16 = True |
| co.shard = True |
| faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co) |
| if not faiss_index.is_trained: |
| faiss_index.train(all_embeddings) |
| faiss_index.add(all_embeddings) |
| faiss_index = faiss.index_gpu_to_cpu(faiss_index) |
| else: |
| if not faiss_index.is_trained: |
| faiss_index.train(all_embeddings) |
| faiss_index.add(all_embeddings) |
|
|
| faiss.write_index(faiss_index, self.index_save_path) |
| print("Finish!") |
|
|
|
|
| MODEL2POOLING = { |
| "e5": "mean", |
| "bge": "cls", |
| "contriever": "mean", |
| 'jina': 'mean' |
| } |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description = "Creating index.") |
|
|
| |
| parser.add_argument('--retrieval_method', type=str) |
| parser.add_argument('--model_path', type=str, default=None) |
| parser.add_argument('--corpus_path', type=str) |
| parser.add_argument('--save_dir', default= 'indexes/',type=str) |
|
|
| |
| parser.add_argument('--max_length', type=int, default=180) |
| parser.add_argument('--batch_size', type=int, default=512) |
| parser.add_argument('--use_fp16', default=False, action='store_true') |
| parser.add_argument('--pooling_method', type=str, default=None) |
| parser.add_argument('--faiss_type',default=None,type=str) |
| parser.add_argument('--embedding_path', default=None, type=str) |
| parser.add_argument('--save_embedding', action='store_true', default=False) |
| parser.add_argument('--faiss_gpu', default=False, action='store_true') |
| |
| args = parser.parse_args() |
|
|
| if args.pooling_method is None: |
| pooling_method = 'mean' |
| for k,v in MODEL2POOLING.items(): |
| if k in args.retrieval_method.lower(): |
| pooling_method = v |
| break |
| else: |
| if args.pooling_method not in ['mean','cls','pooler']: |
| raise NotImplementedError |
| else: |
| pooling_method = args.pooling_method |
|
|
|
|
| index_builder = Index_Builder( |
| retrieval_method = args.retrieval_method, |
| model_path = args.model_path, |
| corpus_path = args.corpus_path, |
| save_dir = args.save_dir, |
| max_length = args.max_length, |
| batch_size = args.batch_size, |
| use_fp16 = args.use_fp16, |
| pooling_method = pooling_method, |
| faiss_type = args.faiss_type, |
| embedding_path = args.embedding_path, |
| save_embedding = args.save_embedding, |
| faiss_gpu = args.faiss_gpu |
| ) |
| index_builder.build_index() |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|