# vectorsearch / helpers.py
# (JPBianchi, commit 30ffb9e: "temp before HF pull")
from typing import List, Tuple, Dict, Any
import time
from tqdm.notebook import tqdm
from rich import print
from retrieval_evaluation import calc_hit_rate_scores, calc_mrr_scores, record_results, add_params
from llama_index.finetuning import EmbeddingQAFinetuneDataset
from weaviate_interface import WeaviateClient
def _first_rank_by_doc_id(response) -> Dict[str, int]:
    '''
    Map each doc_id in a retrieval response to the 1-based position at which it
    FIRST appears. If the same doc_id is returned more than once, later
    occurrences are ignored, so the document is scored by its best rank (which
    is what reciprocal rank requires).

    Args:
    -----
    response:
        Iterable of result mappings, each of which must contain a 'doc_id' key.
    '''
    ranks: Dict[str, int] = {}
    for position, result in enumerate(response, 1):
        ranks.setdefault(result['doc_id'], position)
    return ranks


def retrieval_evaluation(dataset: EmbeddingQAFinetuneDataset,
                         class_name: str,
                         retriever: WeaviateClient,
                         retrieve_limit: int=5,
                         chunk_size: int=256,
                         hnsw_config_keys: List[str] | None=None,
                         display_properties: List[str] | None=None,
                         dir_outpath: str='./eval_results',
                         include_miss_info: bool=False,
                         user_def_params: Dict[str, Any] | None=None
                         ) -> Dict[str, str|int|float] | Tuple[Dict[str, str|int|float], List[Dict[str, Any]]]:
    '''
    Given a dataset and a retriever, evaluate the performance of the retriever.
    Returns a dict of keyword and vector hit rates and MRR scores. If
    include_miss_info is True, also returns a list of the keyword and vector
    responses (and their queries) that did not produce a hit, for deeper
    analysis. A text file with the results is automatically saved in the
    dir_outpath directory.

    Args:
    -----
    dataset: EmbeddingQAFinetuneDataset
        Dataset to be used for evaluation. For each query, only the FIRST entry
        in dataset.relevant_docs[query_id] is used as the ground-truth doc_id.
    class_name: str
        Name of Class on Weaviate host to be used for retrieval
    retriever: WeaviateClient
        WeaviateClient object to be used for retrieval
    retrieve_limit: int=5
        Number of documents to retrieve from Weaviate host
    chunk_size: int=256
        Number of tokens used to chunk text. This value is purely for results
        recording purposes and does not affect results.
    hnsw_config_keys: List[str] | None=None
        HNSW index config keys passed to add_params for recording.
        None means ['maxConnections', 'efConstruction', 'ef'].
    display_properties: List[str] | None=None
        List of properties to be returned from Weaviate host for display in
        response. Must include 'doc_id', which is used for scoring.
        None means ['doc_id', 'guest', 'content'].
    dir_outpath: str='./eval_results'
        Directory path for saving results. Directory will be created if it does not
        already exist.
    include_miss_info: bool=False
        Option to include queries and their associated kw and vector response values
        for queries that are "total misses" (no hit from either search).
    user_def_params: dict=None
        Option for user to pass in a dictionary of user-defined parameters and their values.

    Returns:
    --------
    results_dict, or (results_dict, miss_info) when include_miss_info is True.
        results_dict holds 'n', 'Retriever', 'chunk_size', 'kw_hit_rate',
        'kw_mrr', 'vector_hit_rate', 'vector_mrr', 'total_misses',
        'total_questions' plus whatever add_params adds. The rate/mrr values are
        raw counts here and are finalised by calc_hit_rate_scores and
        calc_mrr_scores. Each miss_info entry is a dict with the keys 'query',
        'kw_response' and 'vector_response'.
    '''
    # None sentinels avoid sharing one mutable default list across calls
    # (the lists are handed to external callees that could mutate them).
    if hnsw_config_keys is None:
        hnsw_config_keys = ['maxConnections', 'efConstruction', 'ef']
    if display_properties is None:
        display_properties = ['doc_id', 'guest', 'content']

    results_dict = {'n': retrieve_limit,
                    'Retriever': retriever.model_name_or_path,
                    'chunk_size': chunk_size,
                    'kw_hit_rate': 0,
                    'kw_mrr': 0,
                    'vector_hit_rate': 0,
                    'vector_mrr': 0,
                    'total_misses': 0,
                    'total_questions': 0
                    }
    # add hnsw configs and user defined params (if any)
    results_dict = add_params(retriever, class_name, results_dict, user_def_params, hnsw_config_keys)

    start = time.perf_counter()
    miss_info = []
    for query_id, q in tqdm(dataset.queries.items(), 'Queries'):
        # Incremented before the try: a query whose retrieval fails still counts
        # toward total_questions even though it contributes no hit.
        results_dict['total_questions'] += 1
        try:
            # make Keyword and Vector calls to Weaviate host
            kw_response = retriever.keyword_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
            vector_response = retriever.vector_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
            # doc_id -> 1-based position of its first occurrence in each response
            kw_ranks = _first_rank_by_doc_id(kw_response)
            vector_ranks = _first_rank_by_doc_id(vector_response)
            # extract ground-truth doc_id for scoring purposes
            doc_id = dataset.relevant_docs[query_id][0]
        except Exception as e:
            # best-effort evaluation: report which query failed and move on
            print(f'Error on query {query_id}: {e}')
            continue

        # increment hit_rate counters and mrr scores for each search type
        hit = False
        for prefix, ranks in (('kw', kw_ranks), ('vector', vector_ranks)):
            rank = ranks.get(doc_id)
            if rank is not None:
                results_dict[f'{prefix}_hit_rate'] += 1
                results_dict[f'{prefix}_mrr'] += 1 / rank
                hit = True

        # if no hits from either search, let's capture that
        if not hit:
            results_dict['total_misses'] += 1
            if include_miss_info:
                miss_info.append({'query': q, 'kw_response': kw_response, 'vector_response': vector_response})

    # use raw counts to calculate final scores
    calc_hit_rate_scores(results_dict)
    calc_mrr_scores(results_dict)
    elapsed = time.perf_counter() - start
    print(f'Total Processing Time: {round(elapsed/60, 2)} minutes')
    record_results(results_dict, chunk_size, dir_outpath=dir_outpath, as_text=True)
    if include_miss_info:
        return results_dict, miss_info
    return results_dict