Spaces:
Sleeping
Sleeping
"""Classical retrieval metrics (Precision@k, Recall@k, MRR, MAP).""" | |
from __future__ import annotations | |
from typing import List, Sequence, Set | |
def precision_at_k(retrieved: Sequence[str], relevant: Set[str], k: int) -> float:
    """Return the fraction of the top-``k`` retrieved documents that are relevant.

    The denominator is the number of documents actually present in the
    top-``k`` slice, so a ranking shorter than ``k`` is not penalised for
    the positions it does not fill.

    Args:
        retrieved: Document ids in ranked order, best first.
        relevant: Ids of the documents judged relevant.
        k: Cutoff rank. Zero or a negative value selects no documents.

    Returns:
        Precision in ``[0.0, 1.0]``; ``0.0`` when the cutoff selects nothing.
    """
    # A negative k would otherwise slice from the end of the ranking
    # (``retrieved[:-1]`` drops the last item) and yield a meaningless score.
    if k <= 0:
        return 0.0
    top_k = retrieved[:k]
    if not top_k:
        return 0.0
    return sum(1 for doc_id in top_k if doc_id in relevant) / len(top_k)
def recall_at_k(retrieved: Sequence[str], relevant: Set[str], k: int) -> float:
    """Return the fraction of relevant documents found in the top ``k`` results.

    Args:
        retrieved: Document ids in ranked order, best first.
        relevant: Ids of the documents judged relevant.
        k: Cutoff rank. Zero or a negative value selects no documents.

    Returns:
        Recall in ``[0.0, 1.0]``; ``0.0`` when ``relevant`` is empty or the
        cutoff selects nothing.
    """
    # k <= 0 is handled explicitly: a negative k would slice from the end of
    # the ranking instead of selecting an empty prefix.
    if not relevant or k <= 0:
        return 0.0
    # Collect distinct hits so that a relevant id repeated in the ranking is
    # counted once; counting every occurrence could push recall above 1.0.
    found = {doc_id for doc_id in retrieved[:k] if doc_id in relevant}
    return len(found) / len(relevant)
def mean_reciprocal_rank(retrieved: Sequence[str], relevant: Set[str]) -> float:
    """Return the reciprocal of the 1-based rank of the first relevant document.

    Args:
        retrieved: Document ids in ranked order, best first.
        relevant: Ids of the documents judged relevant.

    Returns:
        ``1.0 / rank`` of the first id in ``retrieved`` that is in
        ``relevant``, or ``0.0`` when no retrieved id is relevant.
    """
    # Lazily yields one reciprocal rank per relevant hit; only the first is used.
    reciprocal_ranks = (
        1.0 / position
        for position, candidate in enumerate(retrieved, start=1)
        if candidate in relevant
    )
    return next(reciprocal_ranks, 0.0)
def average_precision(retrieved: Sequence[str], relevant: Set[str]) -> float:
    """Return the average precision of a single ranked result list.

    Precision is sampled at the rank of each relevant document's first
    occurrence, and the samples are summed and divided by the total number of
    relevant documents, so relevant documents that were never retrieved
    lower the score.

    Args:
        retrieved: Document ids in ranked order, best first.
        relevant: Ids of the documents judged relevant.

    Returns:
        Average precision in ``[0.0, 1.0]``; ``0.0`` when ``relevant`` is
        empty.
    """
    if not relevant:
        return 0.0
    credited: Set[str] = set()
    precision_sum = 0.0
    for rank, doc_id in enumerate(retrieved, start=1):
        # A repeated relevant id still occupies a rank but earns no further
        # credit; counting it again could push the score above 1.0.
        if doc_id in relevant and doc_id not in credited:
            credited.add(doc_id)
            precision_sum += len(credited) / rank
    # Normalise by the number of relevant documents, not by the hit count.
    return precision_sum / len(relevant)