File size: 4,490 Bytes
d6585f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This script generates the training triples .tsv, one `qid\tpos-docid\tneg-docid` triple per line,
# for both the MS MARCO Document and Passage collections.
#
# Usage:
# python scripts/msmarco_v2/generate_train_triplet.py \
#   -r v2_train_top100.txt \
#   -q v2_train_qrels.tsv \
#   -nneg 40 \
#   -o train-triple-ids.nneg-40.tsv \
#   --topk 1000

import os
import random
import argparse
from collections import defaultdict
from tqdm import tqdm

def load_qrels(fn):
    """
    Load a TREC-format query relevance file into a dictionary.

    Every non-blank line must hold exactly four whitespace-separated fields,
    ``qid <unused> docid label``. Blank lines (for example a trailing empty
    line at the end of the file) are skipped instead of aborting the load.

    :param fn: qrel file path
    :return: defaultdict, in format {qid: {docid: label, ...}, ...} with int labels
    :raises ValueError: if a non-blank line does not have exactly four fields
        or its label is not an integer
    """
    qrels = defaultdict(dict)
    with open(fn, "r", encoding="utf-8") as f:
        for line in f:
            # str.split() with no argument already discards surrounding whitespace.
            fields = line.split()
            if not fields:
                continue
            qid, _, docid, label = fields
            qrels[qid][docid] = int(label)
    return qrels


def load_run(fn, topk):
    """
    Load a TREC-format run file into a dictionary of ranked docids.

    Every non-blank line must hold exactly six whitespace-separated fields,
    ``qid <unused> docid <unused> score <unused>``. Blank lines are skipped
    instead of aborting the load. The rank column of the file is ignored:
    documents are re-ranked by descending score (ties keep file order because
    the sort is stable) and only the first ``topk`` docids are kept.

    :param fn: runfile path
    :param topk: top results to include
    :return: defaultdict, in format {qid: [docid, ...], ...}
    :raises ValueError: if a non-blank line does not have exactly six fields
        or its score is not a number
    """
    run = defaultdict(list)
    with open(fn, "r", encoding="utf-8") as f:
        for line in f:
            # str.split() with no argument already discards surrounding whitespace.
            fields = line.split()
            if not fields:
                continue
            qid, _, docid, _, score, _ = fields
            run[qid].append((docid, float(score)))

    sorted_run = defaultdict(list)
    for query_id, docid_scores in tqdm(run.items()):
        docid_scores.sort(key=lambda x: x[1], reverse=True)
        sorted_run[query_id] = [doc_id for doc_id, _ in docid_scores[:topk]]

    return sorted_run


def open_as_write(fn):
    """
    Open a file for text writing, creating its parent directories if needed.

    The file is written as UTF-8, matching the encoding used when reading the
    qrel and run files, rather than the platform's locale default.

    :param fn: output file path; a bare filename is opened in the current directory
    :return: writable text file object (the caller is responsible for closing it)
    """
    parent = os.path.dirname(fn)
    if parent != "":
        os.makedirs(parent, exist_ok=True)
    return open(fn, "w", encoding="utf-8")


def main(args):
    """
    Write training triples ``qid\\tpos-docid\\tneg-docid`` to ``args.output``.

    For every query in the qrel file that also appears in the run file, each
    positive document (label > 0) is paired with ``args.n_neg_per_query``
    negatives drawn with replacement from the query's top-k run documents
    whose label is 0 or missing. Queries that have positives but no negative
    candidate are skipped, because there is nothing to sample from.

    :param args: parsed arguments providing ``output``, ``run_file``,
        ``qrel_file``, ``topk``, ``n_neg_per_query`` and ``require_pos_in_topk``
    :return: None; triples are written to ``args.output`` and a summary is printed
    """
    # NOTE: kept as an assert so the raised exception type stays the same;
    # be aware that asserts are stripped when Python runs with -O.
    assert args.output.endswith(".tsv"), "output path must end with .tsv"
    n_neg = args.n_neg_per_query
    require_pos_in_topk = args.require_pos_in_topk
    run = load_run(args.run_file, args.topk)
    qrels = load_qrels(args.qrel_file)
    n_not_in_topk, n_no_neg, n_total = 0, 0, len(qrels)

    with open_as_write(args.output) as fout:
        for qid, labels in tqdm(qrels.items()):
            if qid not in run:
                continue

            top_k = run[qid]
            if require_pos_in_topk:
                # Keep only positives that were actually retrieved, in run order.
                pos_docids = [docid for docid in top_k if labels.get(docid, 0) > 0]
            else:
                pos_docids = [docid for docid, label in labels.items() if label > 0]

            # Unjudged documents (missing from qrels) count as negatives.
            neg_docids = [docid for docid in top_k if labels.get(docid, 0) == 0]

            if not pos_docids:
                n_not_in_topk += 1
                continue

            if not neg_docids:
                # random.choices raises IndexError on an empty population.
                n_no_neg += 1
                continue

            for pos_docid in pos_docids:
                sampled_neg_docids = random.choices(neg_docids, k=n_neg)
                fout.writelines(
                    f"{qid}\t{pos_docid}\t{neg_docid}\n" for neg_docid in sampled_neg_docids
                )

    print(f"Finished. {n_not_in_topk} out of {n_total} queries have no positive document in the runfile.")
    if n_no_neg:
        print(f"Skipped {n_no_neg} queries that have no negative candidate in the runfile.")


if __name__ == "__main__":
    # Command-line entry point: parse the options, fix the random seed so the
    # sampled negatives are reproducible across runs, then generate the triples.
    parser = argparse.ArgumentParser(description='Generate MS MARCO V2 training triple .tsv')
    parser.add_argument('--run-file', '-r', required=True, help='MS MARCO V2 doc or passage train_top100.txt path.')
    parser.add_argument('--qrel-file', '-q', required=True, help='MS MARCO V2 doc or passage train_qrels.tsv path.')
    parser.add_argument('--output', '-o', required=True, help='output training triple .tsv path')
    # Despite the option name, negatives are drawn once per positive document,
    # so a query with several positives yields several groups of triples.
    parser.add_argument('--n-neg-per-query', '-nneg', default=40, type=int,
                        help='number of negative documents sampled (with replacement) for each positive document of a query')
    parser.add_argument('--topk', default=1000, type=int, help='top-k documents in the run file from which we sample negatives')
    parser.add_argument('--require-pos-in-topk', action='store_true', default=False, help='if specified, then only keep the positive documents if they appear in the given runfile')
    args = parser.parse_args()

    random.seed(123_456)
    main(args)