from typing import Any, Callable, Dict, List, Optional

from air_benchmark import AIRBench, Retriever
from llama_index.core import VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.schema import Document, NodeWithScore
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.retrievers.bm25 import BM25Retriever


def create_retriever_fn(documents: List[Document], top_k: int) -> Callable[[str], List[NodeWithScore]]:
    # IMPORTANT: if you don't use a llama-index node parser/splitter, you need to ensure
    # that node.ref_doc_id points to the correct parent document id.
    # This is used to line up the corpus document id for evaluation.
    # (See the sketch at the end of this file for setting the linkage manually.)
    nodes = SentenceSplitter(chunk_size=1024, chunk_overlap=128)(documents)

    # Dense retriever over OpenAI embeddings
    vector_index = VectorStoreIndex(
        nodes=nodes,
        embed_model=OpenAIEmbedding(model_name="text-embedding-ada-002"),
    )
    vector_retriever = vector_index.as_retriever(similarity_top_k=top_k)

    # Sparse (lexical) retriever over the same nodes
    bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k)

    # Fuse dense and sparse results, expanding each query into multiple sub-queries
    retriever = QueryFusionRetriever(
        [vector_retriever, bm25_retriever],
        similarity_top_k=top_k,
        num_queries=3,
        mode="dist_based_score",
        llm=OpenAI(model="gpt-3.5-turbo", temperature=0.1),
    )

    def _retriever(query: str) -> List[NodeWithScore]:
        return retriever.retrieve(query)

    return _retriever


class LlamaRetriever(Retriever):
    def __init__(
        self,
        name: str,
        create_retriever_fn: Callable[[List[Document], int], Callable[[str], List[NodeWithScore]]],
        search_top_k: int = 1000,
    ) -> None:
        self.name = name
        self.search_top_k = search_top_k
        self.create_retriever_fn = create_retriever_fn

    def __str__(self):
        return self.name

    def __call__(
        self,
        corpus: Dict[str, Dict[str, Any]],
        queries: Dict[str, str],
        **kwargs,
    ) -> Dict[str, Dict[str, float]]:
        """Retrieve relevant documents for each query."""
        # Wrap each corpus entry in a llama-index Document, keeping the corpus id
        # as the document id so retrieved nodes can be scored against the qrels.
        documents = []
        for doc_id, doc in corpus.items():
            text = doc.pop("text")
            assert text is not None
            documents.append(Document(id_=doc_id, text=text, metadata={**doc}))

        retriever = self.create_retriever_fn(documents, self.search_top_k)

        query_ids = list(queries.keys())
        results = {qid: {} for qid in query_ids}
        for qid in query_ids:
            query = queries[qid]
            if isinstance(query, list):
                # taken from mteb:
                # https://github.com/embeddings-benchmark/mteb/blob/main/mteb/evaluation/evaluators/RetrievalEvaluator.py#L403
                query = "; ".join(query)
            nodes = retriever(query)
            for node in nodes:
                # ref_doc_id should point to the corpus document id
                results[qid][node.node.ref_doc_id] = node.score

        return results


retriever = LlamaRetriever("vector_bm25_fusion", create_retriever_fn)

evaluation = AIRBench(
    benchmark_version="AIR-Bench_24.04",
    task_types=["long-doc"],  # remove this line if you want to evaluate on all task types
    domains=["arxiv"],  # remove this line if you want to evaluate on all domains
    languages=["en"],  # remove this line if you want to evaluate on all languages
    # cache_dir="~/.air_bench/"  # path to the cache directory (**NEEDS ~52GB FOR THE FULL BENCHMARK**)
)
evaluation.run(
    retriever,
    output_dir="./llama_results",  # path to the output directory, default is "./search_results"
    overwrite=True,  # set to True if you want to overwrite existing results
)
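

# ---------------------------------------------------------------------------
# Optional sketch (not called by the script above): if you build nodes without
# a llama-index node parser/splitter, ref_doc_id is not populated automatically.
# The helper below is hypothetical and only illustrates one way to set the
# source-document linkage by hand so that node.ref_doc_id resolves to the
# corpus document id that the evaluation expects.
def build_nodes_manually(corpus_texts: Dict[str, str]):
    from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode

    nodes = []
    for doc_id, text in corpus_texts.items():
        node = TextNode(text=text)
        # node.ref_doc_id is derived from the SOURCE relationship, so point it
        # at the parent corpus document id.
        node.relationships[NodeRelationship.SOURCE] = RelatedNodeInfo(node_id=doc_id)
        nodes.append(node)
    return nodes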