# NOTE(review): removed stray non-Python header lines ("Spaces:", "Runtime error" x2)
# left over from a copy/paste or extraction artifact; they would break importing
# this module.
# NOTE: This code is taken from the original KILT library's retrieval evaluation script | |
# https://github.com/facebookresearch/KILT/blob/9bcb119a7ed5fda88826058b062d0e45c726c676/kilt/eval_retrieval.py | |
import argparse | |
import pprint | |
import json | |
from collections import defaultdict, OrderedDict | |
import os | |
from pyserini.query_iterator import KiltQueryIterator | |
########################################################################################## | |
# Replaced: | |
# from kilt import kilt_utils | |
# With the following directly imported code: | |
def load_data(filename):
    """Load a JSON-lines file (one JSON object per line).

    Args:
        filename: path to a JSONL file.

    Returns:
        A list of parsed JSON objects, in file order.
    """
    data = []
    with open(filename, "r") as fin:
        # Stream line by line instead of readlines(): avoids holding the whole
        # file in memory twice (raw lines + parsed objects).
        for line in fin:
            data.append(json.loads(line))
    return data
########################################################################################## | |
# Replaced: | |
# from kilt import eval_downstream | |
# With the following directly imported code: | |
def validate_input(gold_records, guess_records):
    """Check sizes, enforce unique IDs, and reorder guesses to match gold order.

    Args:
        gold_records: list of gold datapoints, each with an "id" field.
        guess_records: list of predicted datapoints, each with an "id" field.

    Returns:
        (gold_records, guess_records): gold unchanged, guesses reordered so that
        guess_records[i] has the same id as gold_records[i].

    Raises:
        ValueError: if some gold id has no corresponding prediction.
    """
    if len(gold_records) != len(guess_records):
        print(
            "WARNING: DIFFERENT SIZE gold: {} guess: {}".format(
                len(gold_records), len(guess_records)
            )
        )

    # align order
    gold_ids = []
    seen_gold_ids = set()  # O(1) duplicate check instead of O(n) list scans
    for gold in gold_records:
        gold_id = str(gold["id"]).strip()
        assert gold_id not in seen_gold_ids, "Gold IDs should be unique"
        seen_gold_ids.add(gold_id)
        gold_ids.append(gold_id)

    id2guess_record = {}
    for guess in guess_records:
        guess_id = str(guess["id"]).strip()
        assert guess_id not in id2guess_record, "Prediction IDs should be unique"
        id2guess_record[guess_id] = guess

    # rebuild guess_records in gold order (avoid shadowing the builtin `id`)
    guess_records = []
    for gold_id in gold_ids:
        if gold_id in id2guess_record:
            guess_records.append(id2guess_record[gold_id])
        else:
            raise ValueError("ERROR: no prediction provided for id: {}".format(gold_id))

    return gold_records, guess_records
########################################################################################## | |
def _remove_duplicates(obj): | |
obj_tmp = [] | |
for o in obj: | |
if o not in obj_tmp: | |
obj_tmp.append(o) | |
return obj_tmp | |
def _get_ids_list(datapoint, rank_keys, verbose=False):
    """Return, for each output of *datapoint*, the deduplicated list of
    provenance ids, where an id is the "+"-joined, stripped values of
    *rank_keys*. Provenance entries missing any rank key are skipped
    (with a warning when *verbose*)."""
    ids_list = []
    for output in datapoint["output"]:
        current_ids_list = []
        for provenance in output.get("provenance", []):
            absent = [key for key in rank_keys if key not in provenance]
            if absent:
                # cannot build an id for this provenance entry
                missing = set(absent)
                if verbose:
                    print(
                        f"WARNING: missing key(s) {missing} in provenance, unable to compute retrieval for those."
                    )
            else:
                joined = "+".join(str(provenance[key]).strip() for key in rank_keys)
                current_ids_list.append(joined)
        # consider only unique ids
        ids_list.append(_remove_duplicates(current_ids_list))
    return ids_list
def get_rank(guess_item, gold_item, k, rank_keys, verbose=False):
    """
    The main idea is to consider each evidence set as a single point in the rank.
    The score in the rank for an evidence set is given by the lowest scored evidence in the set.

    Args:
        guess_item: predicted datapoint; only its first output's provenance is ranked.
        gold_item: gold datapoint whose outputs define the evidence sets.
        k: rank cutoff; must be > 0.
        rank_keys: provenance keys whose joined, stripped values identify a page.
        verbose: if True, warn when fewer predictions than needed for a robust @k are given.

    Returns:
        (rank, num_distinct_evidence_sets): `rank` is a list aligned with the
        predicted ids, containing True (a gold evidence set fully covered at
        this position), False (non-evidence page), or the string
        "evidence_set:<idx>" (placeholder for a partially covered set);
        `num_distinct_evidence_sets` is the number of distinct gold evidence sets.
    """
    assert k > 0, "k must be a positive integer grater than 0."
    rank = []
    num_distinct_evidence_sets = 0
    # only the first guess output is considered (get_ranking_metrics asserts there is exactly one)
    guess_ids = _get_ids_list(guess_item, rank_keys)[0]
    if guess_ids and len(guess_ids) > 0:
        # 1. collect evidence sets and their sizes
        evidence_sets = []
        e_size = defaultdict(int)  # maps evidence-set size -> how many sets have that size
        for output in gold_item["output"]:
            if "provenance" in output:
                e_set = {
                    "+".join(
                        [str(provenance[rank_key]).strip() for rank_key in rank_keys]
                    )
                    for provenance in output["provenance"]
                }
                if e_set not in evidence_sets:  # no duplicate evidence set
                    evidence_sets.append(e_set)
                    e_size[len(e_set)] += 1
        num_distinct_evidence_sets = len(evidence_sets)
        # 2. check what's the minimum number of predicted pages needed to get a robust P/R@k
        # Greedily take the k largest evidence sets; their combined size is how many
        # predictions would be needed so each of them could be fully covered.
        min_prediction_size = 0
        c = 0
        for size, freq in sorted(e_size.items(), reverse=True):
            for _ in range(freq):
                min_prediction_size += size
                c += 1
                if c == k:
                    break
            if c == k:
                break
        # if the number of evidence sets is smaller than k
        min_prediction_size += k - c
        if verbose and len(guess_ids) < min_prediction_size:
            print(
                f"WARNING: you should provide at least {min_prediction_size} provenance items for a robust recall@{k} computation (you provided {len(guess_ids)} item(s))."
            )
        # 3. rank by gruping pages in each evidence set (each evidence set count as 1),
        # the position in the rank of each evidence set is given by the last page in guess_ids
        # non evidence pages counts as 1
        rank = []
        for guess_id in guess_ids:
            guess_id = str(guess_id).strip()
            found = False
            for idx, e_set in enumerate(evidence_sets):
                e_set_id = f"evidence_set:{idx}"
                if guess_id in e_set:
                    found = True
                    # remove from the rank previous points referring to this evidence set
                    # (the set's rank position is always that of its latest matched page)
                    if e_set_id in rank:
                        rank.remove(e_set_id)
                    # remove the guess_id from the evidence set
                    # (e_set was built locally above, so mutating it is safe)
                    e_set.remove(guess_id)
                    if len(e_set) == 0:
                        # it was the last evidence, it counts as true in the rank
                        rank.append(True)
                    else:
                        # add a point for this partial evidence set
                        rank.append(e_set_id)
            if not found:
                rank.append(False)
    return rank, num_distinct_evidence_sets
# 1. Precision computation | |
def _precision_at_k(rank, k): | |
# precision @ k | |
p = rank[:k].count(True) / k | |
return p | |
# 2. Recall computation | |
def _recall_at_k(rank, num_distinct_evidence_sets, k): | |
r = rank[:k].count(True) / num_distinct_evidence_sets | |
return r | |
# 3. Success rate computation | |
def _success_rate_at_k(rank, k): | |
# success rate @ k | |
p = int(True in rank[:k]) | |
return p | |
def _computeRprec(guess_ids, gold_ids): | |
R = len(gold_ids) | |
num = 0 | |
for prediction in guess_ids[:R]: | |
if str(prediction).strip() in gold_ids: | |
num += 1 | |
Rprec = num / R if R > 0 else 0 | |
return Rprec | |
# R-precision https://link.springer.com/referenceworkentry/10.1007%2F978-0-387-39940-9_486 | |
def rprecision(guess_item, gold_item, rank_keys):
    """R-precision of the guess against *gold_item*: the best R-precision
    achieved over all of the gold item's evidence lists.

    https://link.springer.com/referenceworkentry/10.1007%2F978-0-387-39940-9_486
    """
    gold_ids_list = _get_ids_list(gold_item, rank_keys)
    guess_ids = _get_ids_list(guess_item, rank_keys)[0]
    return max(_computeRprec(guess_ids, gold_ids) for gold_ids in gold_ids_list)
def get_ranking_metrics(guess_item, gold_item, ks, rank_keys):
    """Compute retrieval metrics for one (guess, gold) pair.

    Returns a dict with "Rprec", then "precision@k" for k > 0, "recall@k"
    and "success_rate@k" for k > 1. The @k metrics stay at 0 when the gold
    item has no distinct evidence set.
    """
    sorted_ks = sorted(ks)
    P_at_k = {f"precision@{k}": 0 for k in sorted_ks if k > 0}
    R_at_k = {f"recall@{k}": 0 for k in sorted_ks if k > 1}
    S_at_k = {f"success_rate@{k}": 0 for k in sorted_ks if k > 1}

    assert (
        "output" in guess_item and len(guess_item["output"]) == 1
    ), f"guess should provide exactly one output for {guess_item['id']}"

    Rprec = rprecision(guess_item, gold_item, rank_keys=rank_keys)
    for k in ks:
        # 0. get rank
        rank, num_distinct_evidence_sets = get_rank(
            guess_item, gold_item, k, rank_keys=rank_keys
        )
        if num_distinct_evidence_sets > 0:
            # 1. precision / 2. recall / 3. success rate
            P_at_k[f"precision@{k}"] = _precision_at_k(rank, k)
            R_at_k[f"recall@{k}"] = _recall_at_k(rank, num_distinct_evidence_sets, k)
            S_at_k[f"success_rate@{k}"] = _success_rate_at_k(rank, k)

    metrics = {"Rprec": Rprec}
    metrics.update(P_at_k)
    metrics.update(R_at_k)
    metrics.update(S_at_k)
    return metrics
def compute(gold_dataset, guess_dataset, ks, rank_keys):
    """Macro-average retrieval metrics over the whole dataset.

    Args:
        gold_dataset: list of gold datapoints.
        guess_dataset: list of predictions, aligned with gold by position/id.
        ks: iterable of cutoffs (ints or int-like strings).
        rank_keys: provenance keys identifying an evidence item.

    Returns:
        OrderedDict with averaged "Rprec", "precision@k" (k > 0), and
        "recall@k" / "success_rate@k" (k > 1).
    """
    ks = sorted([int(x) for x in ks])

    result = OrderedDict()
    result["Rprec"] = 0.0
    for k in ks:
        if k > 0:
            result["precision@{}".format(k)] = 0.0
        if k > 1:
            result["recall@{}".format(k)] = 0.0
            result["success_rate@{}".format(k)] = 0.0

    # BUG FIX: the mismatch message previously reported the guess length under
    # the "gold" label and vice versa.
    assert len(guess_dataset) == len(
        gold_dataset
    ), "different size gold: {} guess: {}".format(len(gold_dataset), len(guess_dataset))

    # BUG FIX: the original zipped (guess_dataset, gold_dataset) but named the
    # items (gold, guess); the id comparison was symmetric so results were
    # unaffected, but the naming was misleading.
    for gold, guess in zip(gold_dataset, guess_dataset):
        assert (
            str(gold["id"]).strip() == str(guess["id"]).strip()
        ), "Items must have same order with same IDs"

    for guess_item, gold_item in zip(guess_dataset, gold_dataset):
        ranking_metrics = get_ranking_metrics(guess_item, gold_item, ks, rank_keys)
        result["Rprec"] += ranking_metrics["Rprec"]
        for k in ks:
            if k > 0:
                result["precision@{}".format(k)] += ranking_metrics[
                    "precision@{}".format(k)
                ]
            if k > 1:
                result["recall@{}".format(k)] += ranking_metrics["recall@{}".format(k)]
                result["success_rate@{}".format(k)] += ranking_metrics[
                    "success_rate@{}".format(k)
                ]

    n = len(guess_dataset)
    if n > 0:
        result["Rprec"] /= n
        for k in ks:
            if k > 0:
                result["precision@{}".format(k)] /= n
            if k > 1:
                result["recall@{}".format(k)] /= n
                result["success_rate@{}".format(k)] /= n
    return result
def evaluate(gold, guess, ks, rank_keys):
    """Load gold and guess KILT files, validate/align them, then compute,
    pretty-print, and return the retrieval metrics."""
    gold_dataset = load_data(gold)
    guess_dataset = load_data(guess)

    # 0. validate input
    gold_dataset, guess_dataset = validate_input(gold_dataset, guess_dataset)

    # 1. get retrieval metrics
    result = compute(gold_dataset, guess_dataset, ks, rank_keys)

    printer = pprint.PrettyPrinter(indent=4)
    printer.pprint(result)
    return result
if __name__ == "__main__":
    # CLI entry point: positional guess/gold files plus optional cutoffs and
    # rank keys, both given as comma-separated strings.
    parser = argparse.ArgumentParser()
    parser.add_argument("guess", help="Guess KILT file")
    parser.add_argument("gold", help="Gold KILT file")
    parser.add_argument(
        "--ks",
        type=str,
        required=False,
        default="1,5,10,20",
        help="Comma separated list of positive integers for recall@k and precision@k",
    )
    parser.add_argument(
        "--rank_keys",
        type=str,
        required=False,
        default="wikipedia_id",
        help="Comma separated list of rank keys for recall@k and precision@k",
    )
    args = parser.parse_args()

    # parse the comma-separated CLI strings into lists
    args.ks = [int(k) for k in args.ks.split(",")]
    args.rank_keys = args.rank_keys.split(",")

    ##########################################################################################
    # Pyserini change:
    # Download gold file if necessary
    gold = args.gold
    if not os.path.exists(gold):
        gold = KiltQueryIterator.download_kilt_topics(gold)
    ##########################################################################################
    evaluate(gold, args.guess, args.ks, args.rank_keys)