from colbert.utils.utils import print_message from utility.utils.dpr import DPR_normalize, has_answer def tokenize_all_answers(args): qid, question, answers = args return qid, question, [DPR_normalize(ans) for ans in answers] def assign_label_to_passage(args): idx, (qid, pid, rank, passage, tokenized_answers) = args if idx % (1*1000*1000) == 0: print(idx) return qid, pid, rank, has_answer(tokenized_answers, passage) def check_sizes(qid2answers, qid2rankings): num_judged_queries = len(qid2answers) num_ranked_queries = len(qid2rankings) print_message('num_judged_queries =', num_judged_queries) print_message('num_ranked_queries =', num_ranked_queries) if num_judged_queries != num_ranked_queries: assert num_ranked_queries <= num_judged_queries print('\n\n') print_message('[WARNING] num_judged_queries != num_ranked_queries') print('\n\n') return num_judged_queries, num_ranked_queries def compute_and_write_labels(output_path, qid2answers, qid2rankings): cutoffs = [1, 5, 10, 20, 30, 50, 100, 1000, 'all'] success = {cutoff: 0.0 for cutoff in cutoffs} counts = {cutoff: 0.0 for cutoff in cutoffs} with open(output_path, 'w') as f: for qid in qid2answers: if qid not in qid2rankings: continue prev_rank = 0 # ranks should start at one (i.e., and not zero) labels = [] for pid, rank, label in qid2rankings[qid]: assert rank == prev_rank+1, (qid, pid, (prev_rank, rank)) prev_rank = rank labels.append(label) line = '\t'.join(map(str, [qid, pid, rank, int(label)])) + '\n' f.write(line) for cutoff in cutoffs: if cutoff != 'all': success[cutoff] += sum(labels[:cutoff]) > 0 counts[cutoff] += sum(labels[:cutoff]) else: success[cutoff] += sum(labels) > 0 counts[cutoff] += sum(labels) return success, counts # def dump_metrics(f, nqueries, cutoffs, success, counts): # for cutoff in cutoffs: # success_log = "#> P@{} = {}".format(cutoff, success[cutoff] / nqueries) # counts_log = "#> D@{} = {}".format(cutoff, counts[cutoff] / nqueries) # print('\n'.join([success_log, counts_log]) + '\n') # f.write('\n'.join([success_log, counts_log]) + '\n\n')