from typing import List, Tuple, Dict, Any
import time
from tqdm.notebook import tqdm
from rich import print

from retrieval_evaluation import calc_hit_rate_scores, calc_mrr_scores, record_results, add_params
from llama_index.finetuning import EmbeddingQAFinetuneDataset
from weaviate_interface import WeaviateClient


def retrieval_evaluation(dataset: EmbeddingQAFinetuneDataset, 
                         class_name: str, 
                         retriever: WeaviateClient,
                         retrieve_limit: int=5,
                         chunk_size: int=256,
                         hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef'],
                         display_properties: List[str]=['doc_id', 'guest', 'content'],
                         dir_outpath: str='./eval_results',
                         include_miss_info: bool=False,
                         user_def_params: Dict[str, Any]=None
                         ) -> Dict[str, Any] | Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    '''
    Given a dataset and a retriever, evaluate the performance of the retriever. Returns a dict of keyword 
    and vector hit rates and MRR scores. If include_miss_info is True, also returns a list of the keyword 
    and vector responses, along with their associated queries, that did not return a hit, for deeper 
    analysis. A text file with the results is automatically saved in the dir_outpath directory.

    Args:
    -----
    dataset: EmbeddingQAFinetuneDataset
        Dataset to be used for evaluation
    class_name: str
        Name of the Class on the Weaviate host to be used for retrieval
    retriever: WeaviateClient
        WeaviateClient object to be used for retrieval
    retrieve_limit: int=5
        Number of documents to retrieve from the Weaviate host
    chunk_size: int=256
        Number of tokens used to chunk the text. This value is recorded purely for 
        results-logging purposes and does not affect the evaluation.
    hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef']
        HNSW configuration keys to read from the Weaviate class config and record with the results
    display_properties: List[str]=['doc_id', 'guest', 'content']
        List of properties to be returned from the Weaviate host for display in the response
    dir_outpath: str='./eval_results'
        Directory path for saving results. The directory will be created if it does not
        already exist.
    include_miss_info: bool=False
        Option to include queries and their associated keyword and vector response values
        for queries that are "total misses"
    user_def_params: Dict[str, Any]=None
        Option for the user to pass in a dictionary of user-defined parameters and their values.
    '''

    results_dict = {'n': retrieve_limit, 
                    'Retriever': retriever.model_name_or_path, 
                    'chunk_size': chunk_size,
                    'kw_hit_rate': 0,
                    'kw_mrr': 0,
                    'vector_hit_rate': 0,
                    'vector_mrr': 0,
                    'total_misses': 0,
                    'total_questions': 0
                    }
    #add hnsw configs and user defined params (if any)
    results_dict = add_params(retriever, class_name, results_dict, user_def_params, hnsw_config_keys)
    
    start = time.perf_counter()
    miss_info = []
    for query_id, q in tqdm(dataset.queries.items(), 'Queries'):
        results_dict['total_questions'] += 1
        hit = False
        #make keyword and vector calls to the Weaviate host
        try:
            kw_response = retriever.keyword_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
            vector_response = retriever.vector_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
            
            #map each returned doc_id to its 1-based rank to check for document matches
            kw_doc_ids = {result['doc_id']:i for i, result in enumerate(kw_response, 1)}
            vector_doc_ids = {result['doc_id']:i for i, result in enumerate(vector_response, 1)}
            
            #extract the gold doc_id (first relevant doc) for this query for scoring purposes
            doc_id = dataset.relevant_docs[query_id][0]
     
            #increment hit_rate counters and mrr scores
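            #MRR credits a hit at rank k with 1/k, e.g. a match at position 1 adds 1.0, at position 4 adds 0.25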
            if doc_id in kw_doc_ids:
                results_dict['kw_hit_rate'] += 1
                results_dict['kw_mrr'] += 1/kw_doc_ids[doc_id]
                hit = True
            if doc_id in vector_doc_ids:
                results_dict['vector_hit_rate'] += 1
                results_dict['vector_mrr'] += 1/vector_doc_ids[doc_id]
                hit = True
                
            # if no hits, let's capture that
            if not hit:
                results_dict['total_misses'] += 1
                miss_info.append({'query': q, 'kw_response': kw_response, 'vector_response': vector_response})
        except Exception as e:
            print(f'Query {query_id} raised an exception: {e}')
            continue
    

    #use raw counts to calculate final scores
    calc_hit_rate_scores(results_dict)
    calc_mrr_scores(results_dict)
    
    elapsed = time.perf_counter() - start
    print(f'Total Processing Time: {round(elapsed/60, 2)} minutes')
    record_results(results_dict, chunk_size, dir_outpath=dir_outpath, as_text=True)
    
    if include_miss_info:
        return results_dict, miss_info
    return results_dict
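

# Example usage (a minimal sketch): the WeaviateClient constructor arguments,
# the dataset path, and the class name below are assumptions for illustration,
# not part of this module; adapt them to your weaviate_interface setup.
#
# dataset = EmbeddingQAFinetuneDataset.from_json('./data/golden_dataset.json')  # hypothetical path
# client = WeaviateClient(...)  # construct per your weaviate_interface configuration
# results, misses = retrieval_evaluation(dataset,
#                                        class_name='Podcast_256',  # hypothetical class name
#                                        retriever=client,
#                                        include_miss_info=True)
# print(results)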