muryshev's picture
Init
729d130
raw
history blame
No virus
1.6 kB
import numpy as np
def calculate_metrics_at_k(pred: dict, true: dict, k: int, dynamic_topk: bool = False) -> dict:
    """Compute precision, recall and F1 at a cutoff, averaged over queries.

    Args:
        pred: Mapping from query id to the ranked sequence of retrieved
            document ids.
        true: Mapping from query id to an iterable of relevant document ids.
            It is indexed with every query id found in ``pred``.
        k: Cutoff applied to each ranked list. It is also the precision
            denominator, even when fewer than ``k`` documents were retrieved.
        dynamic_topk: When true, the passed ``k`` is not used for scoring.
            Each query is scored over all of its distinct retrieved
            documents, and the cutoff becomes that distinct count.

    Returns:
        A dict with the keys ``avg_precision@{k}``, ``avg_recall@{k}`` and
        ``avg_f1@{k}``. Each value is the mean of the per-query scores, or
        ``0`` when ``pred`` has no queries. With ``dynamic_topk`` the ``k``
        in the key names is the cutoff of the last query processed; this
        labelling is kept as-is so existing callers see the same keys.

    Raises:
        KeyError: If a query id from ``pred`` is missing in ``true``.
    """
    precisions = []
    recalls = []
    f1_scores = []

    for query_id in pred:
        ranked = pred[query_id]
        if dynamic_topk:
            # Score every distinct prediction. Slicing the raw list to the
            # distinct count would cut off trailing documents whenever the
            # list contains duplicates.
            retrieved = set(ranked)
            k = len(retrieved)
        else:
            retrieved = set(ranked[:k])
        relevant = set(true[query_id])

        if not retrieved and not relevant:
            # Nothing was expected and nothing was returned: count the query
            # as answered perfectly instead of scoring it as zero.
            precisions.append(1)
            recalls.append(1)
            f1_scores.append(1)
            continue

        hits = len(retrieved & relevant)

        # Guard the divisions: k may be 0 and the relevant set may be empty.
        precision = hits / k if k else 0
        recall = hits / len(relevant) if relevant else 0

        if precision + recall > 0:
            f1 = 2 * (precision * recall) / (precision + recall)
        else:
            f1 = 0

        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)

    # np.mean of an empty list warns and yields NaN, so fall back to 0.
    avg_precision = np.mean(precisions) if precisions else 0
    avg_recall = np.mean(recalls) if recalls else 0
    avg_f1 = np.mean(f1_scores) if f1_scores else 0

    return {
        f"avg_precision@{k}": avg_precision,
        f"avg_recall@{k}": avg_recall,
        f"avg_f1@{k}": avg_f1,
    }