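"""Example: evaluating a LlamaIndex retriever on AIR-Bench.

Builds a hybrid retriever (dense OpenAI embeddings fused with BM25 via
QueryFusionRetriever), wraps it in AIR-Bench's Retriever interface, and runs
the AIR-Bench_24.04 long-doc / arxiv / en evaluation.
"""
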
from typing import Any, Callable, Dict, List

from air_benchmark import AIRBench, Retriever
from llama_index.core import VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.schema import Document, NodeWithScore


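# Builds a hybrid retriever over the given documents: dense retrieval from a
# VectorStoreIndex plus lexical BM25, fused by a QueryFusionRetriever.
# Returns a closure mapping a query string to scored nodes.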
def create_retriever_fn(documents: List[Document], top_k: int) -> Callable[[str], List[NodeWithScore]]:
    # IMPORTANT: if you don't use a llama-index node parser/splitter, you need to ensure
    # that node.ref_doc_id points to the correct parent document id.
    # This is used to line up the corpus document id for evaluation
    nodes = SentenceSplitter(chunk_size=1024, chunk_overlap=128)(documents)

    vector_index = VectorStoreIndex(
        nodes=nodes,
        embed_model=OpenAIEmbedding(model="text-embedding-ada-002"),
    )
    vector_retriever = vector_index.as_retriever(similarity_top_k=top_k)
    bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k)

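    # Fuse the dense and BM25 result lists. num_queries=3 has the LLM generate two
    # extra query variations (the original query counts as one), and
    # "dist_based_score" normalizes scores by their distribution before merging.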
    retriever = QueryFusionRetriever(
        [vector_retriever, bm25_retriever], 
        similarity_top_k=top_k, 
        num_queries=3, 
        mode="dist_based_score",
        llm=OpenAI(model="gpt-3.5-turbo", temperature=0.1)
    )

    def _retriever(query: str) -> List[NodeWithScore]:
        return retriever.retrieve(query)

    return _retriever



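# Adapter implementing AIR-Bench's Retriever interface: given the corpus and the
# queries as dicts, __call__ must return {query_id: {doc_id: score}}.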
class LlamaRetriever(Retriever):
    def __init__(
        self, 
        name: str, 
        create_retriever_fn: Callable[[List[Document], int], Callable[[str], List[NodeWithScore]]], 
        search_top_k: int = 1000,
    ) -> None:
        self.name = name
        self.search_top_k = search_top_k
        self.create_retriever_fn = create_retriever_fn

    def __str__(self):
        return self.name

    def __call__(
        self,
        corpus: Dict[str, Dict[str, Any]],
        queries: Dict[str, str],
        **kwargs,
    ) -> Dict[str, Dict[str, float]]:
        """
        Retrieve relevant documents for each query
        """
        
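        # Convert the AIR-Bench corpus into LlamaIndex Documents, keeping the corpus
        # doc_id as the Document id so node.ref_doc_id maps back to it for scoring.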
        documents = []
        for doc_id, doc in corpus.items():
            text = doc.pop("text")
            assert text is not None
            documents.append(Document(id_=doc_id, text=text, metadata={**doc}))
        
        retriever = self.create_retriever_fn(documents, self.search_top_k)

        query_ids = list(queries.keys())
        results = {qid: {} for qid in query_ids}
        for qid in query_ids:
            query = queries[qid]
            if isinstance(query, list):
                # taken from mteb:
                # https://github.com/embeddings-benchmark/mteb/blob/main/mteb/evaluation/evaluators/RetrievalEvaluator.py#L403
                query = "; ".join(query)
            
            nodes = retriever(query)
            for node in nodes:
                # ref_doc_id should point to corpus document id
                results[qid][node.node.ref_doc_id] = node.score
        
        return results
        


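# Wire the retriever factory into the AIR-Bench wrapper and run the benchmark;
# search results are written under output_dir.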
retriever = LlamaRetriever("vector_bm25_fusion", create_retriever_fn)

evaluation = AIRBench(
    benchmark_version="AIR-Bench_24.04",
    task_types=["long-doc"],      # remove this line if you want to evaluate on all tasks
    domains=["arxiv"],            # remove this line if you want to evaluate on all domains
    languages=["en"],             # remove this line if you want to evaluate on all languages
    # cache_dir="~/.air_bench/"     # path to the cache directory (**NEED ~52GB FOR FULL BENCHMARK**)
)

evaluation.run(
    retriever,
    output_dir="./llama_results",   # path to the output directory, default is "./search_results"
    overwrite=True,                 # set to True to overwrite existing results
)