# NOTE: This code is taken from the original KILT library's retrieval evaluation script
# https://github.com/facebookresearch/KILT/blob/9bcb119a7ed5fda88826058b062d0e45c726c676/kilt/eval_retrieval.py
import argparse
import pprint
import json
from collections import defaultdict, OrderedDict
import os
from pyserini.query_iterator import KiltQueryIterator
##########################################################################################
# Replaced:
# from kilt import kilt_utils
# With the following directly imported code:
def load_data(filename):
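    """Load a KILT-format JSONL file into a list of records (one dict per line)."""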
    data = []
    with open(filename, "r") as fin:
        lines = fin.readlines()
        for line in lines:
            data.append(json.loads(line))
    return data
##########################################################################################
# Replaced:
# from kilt import eval_downstream
# With the following directly imported code:
def validate_input(gold_records, guess_records):
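    """Warn on size mismatch, enforce unique IDs in both files, and reorder
    guess_records to follow the gold order (error if a prediction is missing)."""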
    if len(gold_records) != len(guess_records):
        print(
            "WARNING: DIFFERENT SIZE gold: {} guess: {}".format(
                len(gold_records), len(guess_records)
            )
        )

    # align order
    gold_ids = []
    for gold in gold_records:
        assert str(gold["id"]).strip() not in gold_ids, "Gold IDs should be unique"
        gold_ids.append(str(gold["id"]).strip())

    id2guess_record = {}
    for guess in guess_records:
        assert (
            str(guess["id"]).strip() not in id2guess_record
        ), "Prediction IDs should be unique"
        id2guess_record[str(guess["id"]).strip()] = guess

    guess_records = []
    for id in gold_ids:
        if id in id2guess_record:
            guess_records.append(id2guess_record[id])
        else:
            raise ValueError("ERROR: no prediction provided for id: {}".format(id))

    return gold_records, guess_records
##########################################################################################
def _remove_duplicates(obj):
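    """Return a copy of obj with duplicate elements removed, preserving order."""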
    obj_tmp = []
    for o in obj:
        if o not in obj_tmp:
            obj_tmp.append(o)
    return obj_tmp
def _get_ids_list(datapoint, rank_keys, verbose=False):
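    """For each output of a KILT datapoint, collect its provenance identifiers
    (rank_keys joined with '+'), skipping provenance entries missing a rank key."""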
    # collect all gold ids
    ids_list = []
    for output in datapoint["output"]:
        current_ids_list = []
        if "provenance" in output:
            for provenance in output["provenance"]:
                if any(rank_key not in provenance for rank_key in rank_keys):
                    missing = set(rank_keys) - set(
                        list(provenance.keys())
                    ).intersection(set(rank_keys))
                    if verbose:
                        print(
                            f"WARNING: missing key(s) {missing} in provenance, unable to compute retrieval for those."
                        )
                else:
                    current_ids_list.append(
                        "+".join(
                            [
                                str(provenance[rank_key]).strip()
                                for rank_key in rank_keys
                            ]
                        )
                    )
        ids_list.append(_remove_duplicates(current_ids_list))  # remove duplicates

    # consider only unique ids
    return ids_list
def get_rank(guess_item, gold_item, k, rank_keys, verbose=False):
"""
The main idea is to consider each evidence set as a single point in the rank.
The score in the rank for an evidence set is given by the lowest scored evidence in the set.
"""
assert k > 0, "k must be a positive integer grater than 0."
    rank = []
    num_distinct_evidence_sets = 0

    guess_ids = _get_ids_list(guess_item, rank_keys)[0]

    if guess_ids and len(guess_ids) > 0:

        # 1. collect evidence sets and their sizes
        evidence_sets = []
        e_size = defaultdict(int)
        for output in gold_item["output"]:
            if "provenance" in output:
                e_set = {
                    "+".join(
                        [str(provenance[rank_key]).strip() for rank_key in rank_keys]
                    )
                    for provenance in output["provenance"]
                }
                if e_set not in evidence_sets:  # no duplicate evidence set
                    evidence_sets.append(e_set)
                    e_size[len(e_set)] += 1
        num_distinct_evidence_sets = len(evidence_sets)

        # 2. check what's the minimum number of predicted pages needed to get a robust P/R@k
        min_prediction_size = 0
        c = 0
        for size, freq in sorted(e_size.items(), reverse=True):
            for _ in range(freq):
                min_prediction_size += size
                c += 1
                if c == k:
                    break
            if c == k:
                break
        # if the number of evidence sets is smaller than k
        min_prediction_size += k - c

        if verbose and len(guess_ids) < min_prediction_size:
            print(
                f"WARNING: you should provide at least {min_prediction_size} provenance items for a robust recall@{k} computation (you provided {len(guess_ids)} item(s))."
            )
        # 3. rank by grouping pages in each evidence set (each evidence set counts as 1),
        # the position in the rank of each evidence set is given by the last page in guess_ids
        # non-evidence pages count as 1
        rank = []
        for guess_id in guess_ids:
            guess_id = str(guess_id).strip()
            found = False
            for idx, e_set in enumerate(evidence_sets):
                e_set_id = f"evidence_set:{idx}"
                if guess_id in e_set:
                    found = True

                    # remove from the rank previous points referring to this evidence set
                    if e_set_id in rank:
                        rank.remove(e_set_id)

                    # remove the guess_id from the evidence set
                    e_set.remove(guess_id)

                    if len(e_set) == 0:
                        # it was the last evidence, it counts as true in the rank
                        rank.append(True)
                    else:
                        # add a point for this partial evidence set
                        rank.append(e_set_id)
            if not found:
                rank.append(False)

    return rank, num_distinct_evidence_sets
# 1. Precision computation
def _precision_at_k(rank, k):
    # precision @ k
    p = rank[:k].count(True) / k
    return p
# 2. Recall computation
def _recall_at_k(rank, num_distinct_evidence_sets, k):
    r = rank[:k].count(True) / num_distinct_evidence_sets
    return r
# 3. Success rate computation
def _success_rate_at_k(rank, k):
    # success rate @ k
    p = int(True in rank[:k])
    return p
def _computeRprec(guess_ids, gold_ids):
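    """R-precision for a single gold ID set: the fraction of the top-R predictions
    (R = number of gold IDs) that appear among the gold IDs."""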
    R = len(gold_ids)
    num = 0
    for prediction in guess_ids[:R]:
        if str(prediction).strip() in gold_ids:
            num += 1
    Rprec = num / R if R > 0 else 0
    return Rprec
# R-precision https://link.springer.com/referenceworkentry/10.1007%2F978-0-387-39940-9_486
def rprecision(guess_item, gold_item, rank_keys):
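    """Compute R-precision against each gold evidence list and return the best score."""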
    gold_ids_list = _get_ids_list(gold_item, rank_keys)
    guess_ids = _get_ids_list(guess_item, rank_keys)[0]
    Rprec_vector = []
    for gold_ids in gold_ids_list:
        Rprec = _computeRprec(guess_ids, gold_ids)
        Rprec_vector.append(Rprec)
    return max(Rprec_vector)
def get_ranking_metrics(guess_item, gold_item, ks, rank_keys):
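    """Compute R-precision, precision@k, recall@k, and success_rate@k for a single
    guess/gold pair, for every k in ks."""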
    Rprec = 0
    P_at_k = {"precision@{}".format(k): 0 for k in sorted(ks) if k > 0}
    R_at_k = {"recall@{}".format(k): 0 for k in sorted(ks) if k > 1}
    S_at_k = {"success_rate@{}".format(k): 0 for k in sorted(ks) if k > 1}

    assert (
        "output" in guess_item and len(guess_item["output"]) == 1
    ), f"guess should provide exactly one output for {guess_item['id']}"

    Rprec = rprecision(guess_item, gold_item, rank_keys=rank_keys)
    for k in ks:
        # 0. get rank
        rank, num_distinct_evidence_sets = get_rank(
            guess_item, gold_item, k, rank_keys=rank_keys
        )

        if num_distinct_evidence_sets > 0:
            # 1. precision
            P_at_k["precision@{}".format(k)] = _precision_at_k(rank, k)

            # 2. recall
            R_at_k["recall@{}".format(k)] = _recall_at_k(
                rank, num_distinct_evidence_sets, k
            )

            # 3. success rate
            S_at_k["success_rate@{}".format(k)] = _success_rate_at_k(rank, k)
        # else:
        #     print(
        #         "WARNING: the number of distinct evidence sets is 0 for {}".format(
        #             gold_item
        #         )
        #     )

    return {"Rprec": Rprec, **P_at_k, **R_at_k, **S_at_k}
def compute(gold_dataset, guess_dataset, ks, rank_keys):
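    """Aggregate the per-query ranking metrics over the whole dataset and return
    their macro-averages in an OrderedDict."""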
    ks = sorted([int(x) for x in ks])

    result = OrderedDict()
    result["Rprec"] = 0.0
    for k in ks:
        if k > 0:
            result["precision@{}".format(k)] = 0.0
        if k > 1:
            result["recall@{}".format(k)] = 0.0
            result["success_rate@{}".format(k)] = 0.0
    assert len(guess_dataset) == len(
        gold_dataset
    ), "different size gold: {} guess: {}".format(len(gold_dataset), len(guess_dataset))

    for gold, guess in zip(gold_dataset, guess_dataset):
        assert (
            str(gold["id"]).strip() == str(guess["id"]).strip()
        ), "Items must have same order with same IDs"
    for guess_item, gold_item in zip(guess_dataset, gold_dataset):
        ranking_metrics = get_ranking_metrics(guess_item, gold_item, ks, rank_keys)

        result["Rprec"] += ranking_metrics["Rprec"]
        for k in ks:
            if k > 0:
                result["precision@{}".format(k)] += ranking_metrics[
                    "precision@{}".format(k)
                ]
            if k > 1:
                result["recall@{}".format(k)] += ranking_metrics["recall@{}".format(k)]
                result["success_rate@{}".format(k)] += ranking_metrics[
                    "success_rate@{}".format(k)
                ]

    if len(guess_dataset) > 0:
        result["Rprec"] /= len(guess_dataset)
        for k in ks:
            if k > 0:
                result["precision@{}".format(k)] /= len(guess_dataset)
            if k > 1:
                result["recall@{}".format(k)] /= len(guess_dataset)
                result["success_rate@{}".format(k)] /= len(guess_dataset)

    return result
def evaluate(gold, guess, ks, rank_keys):
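    """Load gold and guess KILT files, validate them, compute retrieval metrics,
    and pretty-print the result."""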
    pp = pprint.PrettyPrinter(indent=4)

    gold_dataset = load_data(gold)
    guess_dataset = load_data(guess)

    # 0. validate input
    gold_dataset, guess_dataset = validate_input(gold_dataset, guess_dataset)

    # 1. get retrieval metrics
    result = compute(gold_dataset, guess_dataset, ks, rank_keys)
    pp.pprint(result)

    return result
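# Example invocation (file names are illustrative, not shipped with this script):
#   python evaluate_kilt_retrieval.py run.retrieval.jsonl nq-dev-kilt.jsonl \
#       --ks 1,5,10,20 --rank_keys wikipedia_id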
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("guess", help="Guess KILT file")
    parser.add_argument("gold", help="Gold KILT file")
    parser.add_argument(
        "--ks",
        type=str,
        required=False,
        default="1,5,10,20",
        help="Comma separated list of positive integers for recall@k and precision@k",
    )
    parser.add_argument(
        "--rank_keys",
        type=str,
        required=False,
        default="wikipedia_id",
        help="Comma separated list of rank keys for recall@k and precision@k",
    )

    args = parser.parse_args()
    args.ks = [int(k) for k in args.ks.split(",")]
    args.rank_keys = [rank_key for rank_key in args.rank_keys.split(",")]

    ##########################################################################################
    # Pyserini change:
    # Download gold file if necessary
    gold = args.gold
    if not os.path.exists(args.gold):
        gold = KiltQueryIterator.download_kilt_topics(gold)
    ##########################################################################################

    evaluate(gold, args.guess, args.ks, args.rank_keys)