#external files
from openai_interface import GPT_Turbo
from weaviate_interface import WeaviateClient
from llama_index.finetuning import EmbeddingQAFinetuneDataset
from prompt_templates import qa_generation_prompt
from reranker import ReRanker

#standard library imports
import json
import time
import uuid
import os
import re
import random
from datetime import datetime
from typing import List, Dict, Tuple, Union, Literal

#misc
from tqdm import tqdm


class QueryContextGenerator:
    '''
    Class designed for the generation of query/context pairs using a
    Generative LLM. The LLM is used to generate questions from a given
    corpus of text. The query/context pairs can be used to fine-tune 
    an embedding model using a MultipleNegativesRankingLoss loss function
    or can be used to create evaluation datasets for retrieval models.
    '''
    def __init__(self, openai_key: str, model_id: str='gpt-3.5-turbo-0613'):
        self.llm = GPT_Turbo(model=model_id, api_key=openai_key)

    def clean_validate_data(self,
                            data: List[dict], 
                            valid_fields: List[str]=['content', 'summary', 'guest', 'doc_id'],
                            total_chars: int=950
                            ) -> List[dict]:
        '''
        Strip original data chunks so they only contain valid_fields.
        Remove any chunks whose content is total_chars characters or fewer,
        which prevents the LLM from generating questions from sparse content.
        '''
        clean_docs = [{k:v for k,v in d.items() if k in valid_fields} for d in data]
        valid_docs = [d for d in clean_docs if len(d['content']) > total_chars]
        return valid_docs

    def train_val_split(self,
                        data: List[dict],
                        n_train_questions: int, 
                        n_val_questions: int, 
                        n_questions_per_chunk: int=2,
                        total_chars: int=950):
        '''
        Splits corpus into training and validation sets.  Training and 
        validation samples are randomly selected from the corpus. total_chars
        parameter is set based on pre-analysis of average doc length in the 
        training corpus. 
        '''
        clean_data = self.clean_validate_data(data, total_chars=total_chars)
        random.shuffle(clean_data)
        train_index = n_train_questions//n_questions_per_chunk
        valid_index = n_val_questions//n_questions_per_chunk
        end_index = valid_index + train_index
        if end_index > len(clean_data):
            raise ValueError('Cannot create dataset with desired number of questions, try using a larger dataset')
        train_data = clean_data[:train_index]
        valid_data = clean_data[train_index:end_index]
        print(f'Length Training Data: {len(train_data)}')
        print(f'Length Validation Data: {len(valid_data)}')
        return train_data, valid_data

    def generate_qa_embedding_pairs(
                                    self,
                                    data: List[dict],
                                    generate_prompt_tmpl: str=None,
                                    num_questions_per_chunk: int = 2,
                                    ) -> EmbeddingQAFinetuneDataset:
        """
        Generate query/context pairs from a list of documents. The query/context pairs
        can be used for fine-tuning an embedding model using a MultipleNegativesRankingLoss
        or can be used to create an evaluation dataset for retrieval models.

        This function was adapted for this course from the llama_index.finetuning.common module:
        https://github.com/run-llama/llama_index/blob/main/llama_index/finetuning/embeddings/common.py
        """
        generate_prompt_tmpl = generate_prompt_tmpl or qa_generation_prompt
        queries = {}
        relevant_docs = {}
        corpus = {chunk['doc_id'] : chunk['content'] for chunk in data}
        for chunk in tqdm(data):
            summary = chunk['summary']
            guest = chunk['guest']
            transcript = chunk['content']
            node_id = chunk['doc_id']
            query = generate_prompt_tmpl.format(summary=summary, 
                                                guest=guest,
                                                transcript=transcript,
                                                num_questions_per_chunk=num_questions_per_chunk)
            try:
                response = self.llm.get_chat_completion(prompt=query, temperature=0.1, max_tokens=100)
            except Exception as e:
                print(e)
                continue
            result = str(response).strip().split("\n")
            questions = [
                re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
            ]
            questions = [question for question in questions if len(question) > 0]

            for question in questions:
                question_id = str(uuid.uuid4())
                queries[question_id] = question
                relevant_docs[question_id] = [node_id]

        # construct dataset
        return EmbeddingQAFinetuneDataset(
            queries=queries, corpus=corpus, relevant_docs=relevant_docs
        )
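
#The sketch below is an illustrative, assumption-laden example of using the class above
#end to end; it is not part of the original module. The corpus path, environment
#variable name, and question counts are hypothetical placeholders.
def _example_build_qa_dataset() -> EmbeddingQAFinetuneDataset:
    '''Illustrative sketch: build a small validation dataset from a pre-chunked corpus.'''
    with open('./data/corpus_chunks.json') as f:   #hypothetical path to pre-chunked corpus
        corpus = json.load(f)
    generator = QueryContextGenerator(openai_key=os.environ['OPENAI_API_KEY'])
    _, valid_data = generator.train_val_split(corpus,
                                              n_train_questions=100,
                                              n_val_questions=50,
                                              n_questions_per_chunk=2)
    return generator.generate_qa_embedding_pairs(valid_data, num_questions_per_chunk=2)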

def execute_evaluation(dataset: EmbeddingQAFinetuneDataset, 
                       class_name: str, 
                       retriever: WeaviateClient,
                       reranker: ReRanker=None,
                       alpha: float=0.5,
                       retrieve_limit: int=100,
                       top_k: int=5,
                       chunk_size: int=256,
                       hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef'],
                       search_type: Literal['kw', 'vector', 'hybrid', 'all']='all',
                       display_properties: List[str]=['doc_id', 'content'],
                       dir_outpath: str='./eval_results',
                       include_miss_info: bool=False,
                       user_def_params: dict=None
                       ) -> Union[dict, Tuple[dict, List[dict]]]:
    '''
    Given a dataset, a retriever, and a reranker, evaluate the performance of the retriever and reranker.
    Returns a dict of kw, vector, and hybrid hit rates and MRR scores. If include_miss_info is True, will
    also return a list of kw, vector, and hybrid responses and their associated queries that did not return a hit.

    Args:
    -----
    dataset: EmbeddingQAFinetuneDataset
        Dataset to be used for evaluation
    class_name: str
        Name of Class on Weaviate host to be used for retrieval
    retriever: WeaviateClient
        WeaviateClient object to be used for retrieval 
    reranker: ReRanker
        ReRanker model to be used for results reranking
    alpha: float=0.5
        Weighting factor for BM25 and Vector search.
        alpha can be any number from 0 to 1, defaulting to 0.5:
            alpha = 0 executes a pure keyword search method (BM25)
            alpha = 0.5 weighs the BM25 and vector methods evenly
            alpha = 1 executes a pure vector search method
    retrieve_limit: int=100
        Number of documents to retrieve from Weaviate host
    top_k: int=5
        Number of top results to evaluate
    chunk_size: int=256
        Number of tokens used to chunk text
    hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef']
        List of keys to be used for retrieving HNSW Index parameters from Weaviate host
    search_type: Literal['kw', 'vector', 'hybrid', 'all']='all'
        Type of search to be evaluated.  Options are 'kw', 'vector', 'hybrid', or 'all'
    display_properties: List[str]=['doc_id', 'content']
        List of properties to be returned from Weaviate host for display in response
    dir_outpath: str='./eval_results'
        Directory path for saving results.  Directory will be created if it does not
        already exist. 
    include_miss_info: bool=False
        Option to include queries and their associated search response values
        for queries that are "total misses"
    user_def_params : dict=None
        Option for user to pass in a dictionary of user-defined parameters and their values.
        Will be automatically added to the results_dict if correct type is passed.
    '''
        
    reranker_name = reranker.model_name if reranker else "None"
    
    results_dict = {'n':retrieve_limit, 
                    'top_k': top_k,
                    'alpha': alpha,
                    'Retriever': retriever.model_name_or_path, 
                    'Ranker': reranker_name,
                    'chunk_size': chunk_size,
                    'kw_hit_rate': 0,
                    'kw_mrr': 0,
                    'vector_hit_rate': 0,
                    'vector_mrr': 0,
                    'hybrid_hit_rate':0,
                    'hybrid_mrr': 0,
                    'total_misses': 0,
                    'total_questions':0
                    }
    #add extra params to results_dict
    results_dict = add_params(retriever, class_name, results_dict, user_def_params, hnsw_config_keys)
        
    start = time.perf_counter()
    miss_info = []
    for query_id, q in tqdm(dataset.queries.items(), 'Queries'):
        results_dict['total_questions'] += 1
        hit = False
        #make Keyword, Vector, and Hybrid calls to Weaviate host
        try:
            kw_response = retriever.keyword_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
            vector_response = retriever.vector_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
            hybrid_response = retriever.hybrid_search(request=q, class_name=class_name, alpha=alpha, limit=retrieve_limit, display_properties=display_properties)           
            #rerank returned responses if reranker is provided
            if reranker:
                kw_response = reranker.rerank(kw_response, q, top_k=top_k)
                vector_response = reranker.rerank(vector_response, q, top_k=top_k)
                hybrid_response = reranker.rerank(hybrid_response, q, top_k=top_k)
            
            #collect doc_ids to check for document matches (include only the top_k results)
            kw_doc_ids = {result['doc_id']:i for i, result in enumerate(kw_response[:top_k], 1)}
            vector_doc_ids = {result['doc_id']:i for i, result in enumerate(vector_response[:top_k], 1)}
            hybrid_doc_ids = {result['doc_id']:i for i, result in enumerate(hybrid_response[:top_k], 1)}
            
            #extract doc_id for scoring purposes
            doc_id = dataset.relevant_docs[query_id][0]
     
            #increment hit_rate counters and mrr scores
            if doc_id in kw_doc_ids:
                results_dict['kw_hit_rate'] += 1
                results_dict['kw_mrr'] += 1/kw_doc_ids[doc_id]
                hit = True
            if doc_id in vector_doc_ids:
                results_dict['vector_hit_rate'] += 1
                results_dict['vector_mrr'] += 1/vector_doc_ids[doc_id]
                hit = True
            if doc_id in hybrid_doc_ids:
                results_dict['hybrid_hit_rate'] += 1
                results_dict['hybrid_mrr'] += 1/hybrid_doc_ids[doc_id]
                hit = True
            # if no hits, let's capture that
            if not hit:
                results_dict['total_misses'] += 1
                miss_info.append({'query': q, 
                                  'answer': dataset.corpus[doc_id],
                                  'doc_id': doc_id,
                                  'kw_response': kw_response,
                                  'vector_response': vector_response, 
                                  'hybrid_response': hybrid_response})
        except Exception as e:
            print(e)
            continue

    #use raw counts to calculate final scores
    calc_hit_rate_scores(results_dict, search_type=search_type)
    calc_mrr_scores(results_dict, search_type=search_type)
    
    end = time.perf_counter() - start
    print(f'Total Processing Time: {round(end/60, 2)} minutes')
    record_results(results_dict, chunk_size, dir_outpath=dir_outpath, as_text=True)
    
    if include_miss_info:
        return results_dict, miss_info
    return results_dict
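
#Illustrative sketch of calling execute_evaluation; not part of the original module.
#The WeaviateClient and ReRanker constructor arguments and the Weaviate class name
#shown here are assumptions; consult weaviate_interface.py and reranker.py for the
#actual signatures used in this course.
def _example_run_evaluation(dataset: EmbeddingQAFinetuneDataset) -> dict:
    '''Illustrative sketch: score keyword, vector, and hybrid retrieval on a dataset.'''
    retriever = WeaviateClient(api_key=os.environ['WEAVIATE_API_KEY'],      #assumed constructor args
                               endpoint=os.environ['WEAVIATE_ENDPOINT'])
    reranker = ReRanker('cross-encoder/ms-marco-MiniLM-L-6-v2')             #assumed reranker model arg
    return execute_evaluation(dataset,
                              class_name='Impact_theory_minilm_256',        #hypothetical class name
                              retriever=retriever,
                              reranker=reranker,
                              alpha=0.5,
                              retrieve_limit=100,
                              top_k=5,
                              search_type='all')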

def calc_hit_rate_scores(results_dict: Dict[str, Union[str, int]], 
                         search_type: Literal['kw', 'vector', 'hybrid', 'all']='all'
                         ) -> None:
    '''Convert raw hit counts in results_dict into hit rates (hits / total questions).'''
    if search_type == 'all':
        search_types = ['kw', 'vector', 'hybrid']
    else:
        #accept a single search type string (e.g. 'kw') or a list of them
        search_types = [search_type] if isinstance(search_type, str) else list(search_type)
    for prefix in search_types:
        results_dict[f'{prefix}_hit_rate'] = round(results_dict[f'{prefix}_hit_rate']/results_dict['total_questions'], 2)

def calc_mrr_scores(results_dict: Dict[str, Union[str, int]],
                    search_type: Literal['kw', 'vector', 'hybrid', 'all']='all'
                    ) -> None:
    '''Convert accumulated reciprocal ranks in results_dict into MRR scores.'''
    if search_type == 'all':
        search_types = ['kw', 'vector', 'hybrid']
    else:
        #accept a single search type string (e.g. 'kw') or a list of them
        search_types = [search_type] if isinstance(search_type, str) else list(search_type)
    for prefix in search_types:
        results_dict[f'{prefix}_mrr'] = round(results_dict[f'{prefix}_mrr']/results_dict['total_questions'], 2)
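
#Worked example (illustrative): with total_questions = 4 and the relevant doc found at
#vector ranks 1, 3, and 2 (plus one total miss), the raw accumulators would be
#vector_hit_rate = 3 and vector_mrr = 1/1 + 1/3 + 1/2 = 1.83 (to two decimals); after the
#two functions above run, vector_hit_rate = 0.75 and vector_mrr = round(1.83/4, 2) = 0.46.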

def create_dir(dir_path: str) -> None:
    '''
    Checks if directory exists, and creates new directory
    if it does not exist
    '''
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

def record_results(results_dict: Dict[str, Union[str, int]], 
                   chunk_size: int, 
                   dir_outpath: str='./eval_results',
                   as_text: bool=False
                   ) -> None:
    '''
    Write results to output file in either txt or json format

    Args:
    -----
    results_dict: Dict[str, Union[str, int]]
        Dictionary containing results of evaluation
    chunk_size: int
        Size of text chunks in tokens
    dir_outpath: str
        Path to output directory.  Directory only, filename is hardcoded
        as part of this function.
    as_text: bool
        If True, write results as text file.  If False, write as json file.
    '''
    create_dir(dir_outpath)
    time_marker = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    ext = 'txt' if as_text else 'json'
    path = os.path.join(dir_outpath, f'retrieval_eval_{chunk_size}_{time_marker}.{ext}')
    if as_text:
        with open(path, 'a') as f:
            f.write(f"{results_dict}\n")
    else: 
        with open(path, 'w') as f:
            json.dump(results_dict, f, indent=4)
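
#Example (illustrative): record_results(results_dict, chunk_size=256, as_text=True) writes
#a line to a file such as ./eval_results/retrieval_eval_256_2024-01-15-10-30-00.txt,
#where the timestamp portion reflects the current datetime.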

def add_params(client: WeaviateClient, 
               class_name: str, 
               results_dict: dict, 
               param_options: dict, 
               hnsw_config_keys: List[str]
              ) -> dict:
    hnsw_params = {k:v for k,v in client.show_class_config(class_name)['vectorIndexConfig'].items() if k in hnsw_config_keys}
    if hnsw_params:
        results_dict = {**results_dict, **hnsw_params}
    if param_options and isinstance(param_options, dict):
        results_dict = {**results_dict, **param_options}
    return results_dict