# Spaces:
# Runtime error
# Runtime error
import numpy as np
def calculate_metrics_at_k(pred, true, k, dynamic_topk=False):
    """Compute mean Precision@k, Recall@k and F1@k over a set of queries.

    Args:
        pred: Mapping from query id to an ordered sequence of retrieved
            document ids; the first ``k`` entries are treated as the top-k.
        true: Mapping from query id to an iterable of relevant document ids.
            Must contain every query id that appears in ``pred``.
        k: Cutoff rank. Ignored when ``dynamic_topk`` is true.
        dynamic_topk: When true, each query is evaluated over all of its
            distinct retrieved ids, i.e. its cutoff is the number of distinct
            ids in ``pred[query_id]``.

    Returns:
        Dict with keys ``avg_precision@K``, ``avg_recall@K`` and ``avg_f1@K``
        holding the mean of the per-query scores (0 when ``pred`` is empty).
        ``K`` is ``k``, except in dynamic mode, where it is the cutoff of the
        last query iterated (kept for backward compatibility of the keys).

    Raises:
        KeyError: If a query id in ``pred`` has no entry in ``true``.
    """
    precisions_at_k = []
    recalls_at_k = []
    f1_scores_at_k = []
    # Suffix used for the returned keys; see "Returns" above.
    label_k = k

    for query_id in pred:
        ranked = pred[query_id]
        if dynamic_topk:
            # Use the full distinct set rather than ranked[:len(set(ranked))]:
            # with duplicate ids, that slice would drop distinct documents.
            retrieved_documents = set(ranked)
            cutoff = len(retrieved_documents)
            label_k = cutoff
        else:
            cutoff = k
            retrieved_documents = set(ranked[:k])
        relevant_documents = set(true[query_id])

        # Nothing retrieved and nothing relevant counts as a perfect result.
        if not retrieved_documents and not relevant_documents:
            precisions_at_k.append(1)
            recalls_at_k.append(1)
            f1_scores_at_k.append(1)
            continue

        true_positives = len(retrieved_documents & relevant_documents)

        # Divide by the cutoff, not len(retrieved_documents), so that short
        # result lists and duplicate ids count against precision.
        precision_at_k = true_positives / cutoff if cutoff else 0
        recall_at_k = (
            true_positives / len(relevant_documents) if relevant_documents else 0
        )
        if precision_at_k + recall_at_k > 0:
            f1_at_k = 2 * (precision_at_k * recall_at_k) / (precision_at_k + recall_at_k)
        else:
            f1_at_k = 0

        precisions_at_k.append(precision_at_k)
        recalls_at_k.append(recall_at_k)
        f1_scores_at_k.append(f1_at_k)

    avg_precision_at_k = np.mean(precisions_at_k) if precisions_at_k else 0
    avg_recall_at_k = np.mean(recalls_at_k) if recalls_at_k else 0
    avg_f1_at_k = np.mean(f1_scores_at_k) if f1_scores_at_k else 0
    return {
        f"avg_precision@{label_k}": avg_precision_at_k,
        f"avg_recall@{label_k}": avg_recall_at_k,
        f"avg_f1@{label_k}": avg_f1_at_k,
    }