# RAG_Eval/evaluation/metrics/retrieval_metrics.py
"""Classical retrieval metrics (Precision@k, Recall@k, MRR, MAP)."""
from __future__ import annotations
from typing import Sequence, Set


def precision_at_k(retrieved: Sequence[str], relevant: Set[str], k: int) -> float:
    """Fraction of the top-k retrieved documents that are relevant."""
    retrieved_k = retrieved[:k]
    if not retrieved_k:
        return 0.0
    return len([d for d in retrieved_k if d in relevant]) / len(retrieved_k)
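# Illustrative check (hypothetical document ids, not from the original file):
# precision_at_k(["d1", "d2", "d3"], {"d2", "d4"}, k=2) == 0.5, because only
# "d2" among the top-2 results is relevant.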


def recall_at_k(retrieved: Sequence[str], relevant: Set[str], k: int) -> float:
    """Fraction of all relevant documents that appear in the top-k results."""
    retrieved_k = retrieved[:k]
    if not relevant:
        return 0.0
    return len([d for d in retrieved_k if d in relevant]) / len(relevant)
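# Illustrative check (hypothetical document ids, not from the original file):
# recall_at_k(["d1", "d2", "d3"], {"d2", "d4", "d5"}, k=2) == 1/3, since one
# of the three relevant documents appears in the top-2.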


def mean_reciprocal_rank(retrieved: Sequence[str], relevant: Set[str]) -> float:
    """Reciprocal rank of the first relevant document for a single query."""
    for idx, doc_id in enumerate(retrieved, start=1):
        if doc_id in relevant:
            return 1.0 / idx
    return 0.0
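# Note: this returns the reciprocal rank of a single ranked list; MRR proper is
# the mean of this value over all evaluation queries. Illustrative check
# (hypothetical ids): mean_reciprocal_rank(["d1", "d2", "d3"], {"d2"}) == 0.5,
# because the first relevant document appears at rank 2.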


def average_precision(retrieved: Sequence[str], relevant: Set[str]) -> float:
    """Average precision for a single query (MAP is the mean over queries)."""
    if not relevant:
        return 0.0
    hits = 0
    sum_precisions = 0.0
    for idx, doc_id in enumerate(retrieved, start=1):
        if doc_id in relevant:
            hits += 1
            sum_precisions += hits / idx
    # Normalise by the total number of relevant documents, not by the number of hits.
    return sum_precisions / len(relevant)
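

# Minimal smoke test (assumed usage; the document ids and relevance judgements
# below are hypothetical and not part of the original module).
if __name__ == "__main__":
    retrieved = ["d1", "d2", "d3", "d4"]
    relevant = {"d2", "d4", "d5"}

    # Top-2 contains one relevant document out of two retrieved -> 0.5.
    assert precision_at_k(retrieved, relevant, k=2) == 0.5
    # One of the three relevant documents is found in the top-2 -> 1/3.
    assert abs(recall_at_k(retrieved, relevant, k=2) - 1 / 3) < 1e-9
    # The first relevant document appears at rank 2 -> 0.5.
    assert mean_reciprocal_rank(retrieved, relevant) == 0.5
    # Precision at the hit ranks is 1/2 and 2/4; their sum divided by 3 -> 1/3.
    assert abs(average_precision(retrieved, relevant) - 1 / 3) < 1e-9
    print("retrieval_metrics smoke test passed")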