cheesyFishes committed on
Commit
81bb72e
1 Parent(s): 646ee65

Upload run_airbench.py

Browse files
Files changed (1) hide show
  1. run_airbench.py +105 -0
run_airbench.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, List, Optional
2
+
3
+ from air_benchmark import AIRBench, Retriever
4
+ from llama_index.core import VectorStoreIndex
5
+ from llama_index.core.node_parser import SentenceSplitter
6
+ from llama_index.embeddings.openai import OpenAIEmbedding
7
+ from llama_index.llms.openai import OpenAI
8
+ from llama_index.retrievers.bm25 import BM25Retriever
9
+ from llama_index.core.retrievers import QueryFusionRetriever
10
+ from llama_index.core.schema import Document, NodeWithScore
11
+
12
+
13
def create_retriever_fn(documents: List[Document], top_k: int) -> Callable[[str], List[NodeWithScore]]:
    """Build a hybrid (dense vector + BM25) fusion retriever over *documents*.

    IMPORTANT: if you don't use a llama-index node parser/splitter, you must
    ensure that ``node.ref_doc_id`` points to the correct parent document id —
    it is used to line up the corpus document id for evaluation.
    """
    splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=128)
    chunks = splitter(documents)

    embed_model = OpenAIEmbedding(model_name="text-embedding-ada-002")
    index = VectorStoreIndex(nodes=chunks, embed_model=embed_model)

    dense_retriever = index.as_retriever(similarity_top_k=top_k)
    sparse_retriever = BM25Retriever.from_defaults(nodes=chunks, similarity_top_k=top_k)

    # Fuse dense + sparse results; num_queries=3 lets the LLM generate
    # two extra query rewrites in addition to the original query.
    fusion = QueryFusionRetriever(
        [dense_retriever, sparse_retriever],
        similarity_top_k=top_k,
        num_queries=3,
        mode="dist_based_score",
        llm=OpenAI(model="gpt-3.5-turbo", temperature=0.1),
    )

    def _retrieve(query: str) -> List[NodeWithScore]:
        return fusion.retrieve(query)

    return _retrieve
class LlamaRetriever(Retriever):
    """Adapter exposing a llama-index retriever through AIR-Bench's ``Retriever`` API.

    ``create_retriever_fn`` is a factory that, given the corpus documents and a
    top-k, returns a callable mapping a query string to scored nodes.
    """

    def __init__(
        self,
        name: str,
        create_retriever_fn: Callable[[List[Document], int], Callable[[str], List[NodeWithScore]]],
        search_top_k: int = 1000,
    ) -> None:
        self.name = name
        # BUG FIX: the original line was a bare `self.search_top_k` expression,
        # so the attribute was never assigned and the parameter was dropped.
        self.search_top_k = search_top_k
        self.create_retriever_fn = create_retriever_fn

    def __str__(self) -> str:
        return self.name

    def __call__(
        self,
        corpus: Dict[str, Dict[str, Any]],
        queries: Dict[str, str],
        **kwargs,
    ) -> Dict[str, Dict[str, float]]:
        """Retrieve relevant documents for each query.

        Returns a mapping of query id -> {corpus document id -> score}.
        """
        documents: List[Document] = []
        for doc_id, doc in corpus.items():
            text = doc.pop("text")
            # Raise instead of `assert` so validation survives `python -O`.
            if text is None:
                raise ValueError(f"corpus entry {doc_id!r} has no 'text' field")
            documents.append(Document(id_=doc_id, text=text, metadata={**doc}))

        # BUG FIX: the factory's signature is (documents, top_k); the original
        # call omitted top_k and would raise TypeError at runtime.
        retriever = self.create_retriever_fn(documents, self.search_top_k)

        query_ids = list(queries.keys())
        results: Dict[str, Dict[str, float]] = {qid: {} for qid in query_ids}
        for qid in query_ids:
            query = queries[qid]
            if isinstance(query, list):
                # taken from mteb:
                # https://github.com/embeddings-benchmark/mteb/blob/main/mteb/evaluation/evaluators/RetrievalEvaluator.py#L403
                query = "; ".join(query)

            nodes = retriever(query)
            for node in nodes:
                # ref_doc_id should point to the corpus document id
                results[qid][node.node.ref_doc_id] = node.score

        return results
# Guard the benchmark run behind the standard script entry point so importing
# this module (e.g. for reuse of LlamaRetriever) does not kick off a full —
# and potentially very expensive — evaluation as a side effect.
if __name__ == "__main__":
    retriever = LlamaRetriever("vector_bm25_fusion", create_retriever_fn)

    evaluation = AIRBench(
        benchmark_version="AIR-Bench_24.04",
        task_types=["long-doc"],  # remove this line if you want to evaluate on all tasks
        domains=["arxiv"],  # remove this line if you want to evaluate on all domains
        languages=["en"],  # remove this line if you want to evaluate on all languages
        # cache_dir="~/.air_bench/"  # path to the cache directory (**NEED ~52GB FOR FULL BENCHMARK**)
    )

    evaluation.run(
        retriever,
        output_dir="./llama_results",  # path to the output directory, default is "./search_results"
        overwrite=True,  # set to True if you want to overwrite the existing results
    )