import numpy as np
from sklearn.metrics import ndcg_score

from src.retriever import ClinicalCaseRetriever, DummyRetriever


def retrieval_metrics(
    retriever_instance: ClinicalCaseRetriever,
    queries: list[str],
    gold_ids: list[str],
    k: int = 5,
) -> dict:
    """Calculate retrieval metrics for a set of queries.

    Args:
        retriever_instance: An initialized ClinicalCaseRetriever instance.
        queries: A list of query strings.
        gold_ids: The expected 'case_id' string for each query, in the same
            order as `queries`.
        k: The number of top results to consider for Hit@k and NDCG@k.

    Returns:
        A dictionary containing the average Hit@k, MRR, and NDCG@k scores.
    """
    hits, reciprocal_ranks, ndcgs = [], [], []
    print(f"\nCalculating retrieval metrics for {len(queries)} queries (k={k})...")

    for q_idx, (q, gold) in enumerate(zip(queries, gold_ids)):
        print(f"\nProcessing query {q_idx + 1}/{len(queries)}: '{q}' (Expected ID: '{gold}')")
        retrieved_cases, scores = retriever_instance.retrieve_relevant_case(
            q, top_k=k, return_scores=True
        )

        retrieved_ids = [c.get('case_id', 'N/A') for c in retrieved_cases]
        print(f"Retrieved IDs: {retrieved_ids}")
        print(f"Retrieved Scores: {[round(s, 4) for s in scores]}")

        # Hit@k: 1 if the gold case appears anywhere in the top-k results.
        is_hit = int(gold in retrieved_ids)
        hits.append(is_hit)

        # MRR: reciprocal of the (1-based) rank of the gold case, 0 if missed.
        rank = 0
        if is_hit:
            rank = retrieved_ids.index(gold) + 1
            reciprocal_ranks.append(1.0 / rank)
        else:
            reciprocal_ranks.append(0.0)

        # NDCG@k: binary relevance, 1.0 for the gold case and 0.0 otherwise.
        true_relevance = np.asarray([[1.0 if gid == gold else 0.0 for gid in retrieved_ids]])
        predicted_scores = np.asarray([scores])

        current_ndcg = 0.0
        # sklearn's ndcg_score requires more than one retrieved document.
        if true_relevance.shape[1] > 1:
            ndcg_k = min(k, true_relevance.shape[1])
            current_ndcg = ndcg_score(true_relevance, predicted_scores, k=ndcg_k)
        elif true_relevance.shape[1] == 1:
            # With a single result and binary relevance, NDCG reduces to that
            # result's relevance.
            current_ndcg = float(true_relevance[0, 0])
        ndcgs.append(current_ndcg)

        print(f"Hit: {is_hit}, Rank: {rank if rank > 0 else 'N/A'}, NDCG@{k}: {current_ndcg:.4f}")

    avg_hit = np.mean(hits) if hits else 0.0
    avg_mrr = np.mean(reciprocal_ranks) if reciprocal_ranks else 0.0
    avg_ndcg = np.mean(ndcgs) if ndcgs else 0.0

    print(f"\n--- Overall Retrieval Results (k={k}) ---")
    print(f"Average Hit@{k}: {avg_hit:.4f}")
    print(f"Average MRR: {avg_mrr:.4f}")
    print(f"Average NDCG@{k}: {avg_ndcg:.4f}")

return {f"Hit@{k}": avg_hit, |
|
|
f"MRR": avg_mrr, |
|
|
f"NDCG@{k}": avg_ndcg} |