File size: 1,260 Bytes
8521f60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f868144
 
8521f60
 
 
 
 
 
f868144
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""Classical retrieval metrics (Precision@k, Recall@k, MRR, MAP)."""

from __future__ import annotations
from typing import List, Sequence, Set


def precision_at_k(retrieved: Sequence[str], relevant: Set[str], k: int) -> float:
    """Return the fraction of the top-*k* retrieved documents that are relevant.

    The denominator is the number of documents actually inspected, i.e.
    ``min(k, len(retrieved))``, so a ranking shorter than *k* is scored over
    the items it does contain.

    Args:
        retrieved: Document ids in ranked order, best first.
        relevant: Ids of the documents judged relevant.
        k: Cut-off rank. Values less than or equal to zero inspect nothing.

    Returns:
        The precision as a float, or 0.0 when no documents are inspected.
    """
    # Clamp k: a negative value would otherwise slice from the END of the
    # ranking (retrieved[:-1] drops the last item) instead of selecting nothing.
    top_k = retrieved[:max(k, 0)]
    if not top_k:
        return 0.0
    return sum(1 for doc_id in top_k if doc_id in relevant) / len(top_k)

def recall_at_k(retrieved: Sequence[str], relevant: Set[str], k: int) -> float:
    """Return the fraction of relevant documents found in the top *k* results.

    Args:
        retrieved: Document ids in ranked order, best first.
        relevant: Ids of the documents judged relevant.
        k: Cut-off rank. Values less than or equal to zero inspect nothing.

    Returns:
        The recall as a float in [0.0, 1.0]; 0.0 when *relevant* is empty.
    """
    # Clamp k: a negative value would otherwise slice from the END of the
    # ranking instead of selecting nothing.
    top_k = retrieved[:max(k, 0)]
    if not relevant:
        return 0.0
    # Collect distinct hits so an id repeated in the ranking is not counted
    # twice, which would let recall exceed 1.0.
    found = {doc_id for doc_id in top_k if doc_id in relevant}
    return len(found) / len(relevant)

def mean_reciprocal_rank(retrieved: Sequence[str], relevant: Set[str]) -> float:
    """Return the reciprocal rank of the first relevant document in a ranking.

    Ranks are 1-based, so a relevant document in first position scores 1.0,
    in second position 0.5, and so on.

    Args:
        retrieved: Document ids in ranked order, best first.
        relevant: Ids of the documents judged relevant.

    Returns:
        ``1 / rank`` of the first relevant document, or 0.0 if none of the
        retrieved documents is relevant.
    """
    # Lazily yields 1/rank for each relevant hit; only the first is consumed.
    reciprocal_ranks = (
        1.0 / position
        for position, candidate in enumerate(retrieved, start=1)
        if candidate in relevant
    )
    return next(reciprocal_ranks, 0.0)

def average_precision(retrieved: Sequence[str], relevant: Set[str]) -> float:
    """Return the average precision of a single ranking.

    Precision is sampled at the rank of each relevant document's first
    occurrence, and the samples are summed and divided by the total number of
    relevant documents. Relevant documents that are never retrieved therefore
    contribute zero.

    Args:
        retrieved: Document ids in ranked order, best first.
        relevant: Ids of the documents judged relevant.

    Returns:
        The average precision as a float in [0.0, 1.0]; 0.0 when *relevant*
        is empty.
    """
    if not relevant:
        return 0.0
    credited = set()
    precision_sum = 0.0
    for rank, doc_id in enumerate(retrieved, start=1):
        # Credit each relevant id only once: a repeated id still occupies a
        # rank, but counting it as a new hit would let the score exceed 1.0.
        if doc_id in relevant and doc_id not in credited:
            credited.add(doc_id)
            precision_sum += len(credited) / rank
    # Divide by the total number of relevant items, not by the number of hits,
    # so missing relevant documents lower the score.
    return precision_sum / len(relevant)