欧卫
'add_app_files'
58627fa
raw
history blame contribute delete
No virus
2.23 kB
"""
Divide a query set into two.
"""
import os
import math
import ujson
import random
from argparse import ArgumentParser
from collections import OrderedDict
from colbert.utils.utils import print_message
def _load_queries(path):
    """Read a ``qid<TAB>query`` file into an OrderedDict keyed by qid (file order).

    Blank lines are skipped. Only the first tab separates the qid from the
    query text, so a query may itself contain tabs.

    Raises:
        ValueError: if a non-blank line has no tab, or a qid appears twice.
    """
    queries = OrderedDict()
    with open(path, encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing empty line
            qid, sep, query = line.partition('\t')
            if not sep:
                raise ValueError(f"{path}:{line_no}: expected 'qid<TAB>query', got {line!r}")
            if qid in queries:
                raise ValueError(f"{path}:{line_no}: duplicate qid {qid!r}")
            queries[qid] = query
    return queries


def _write_queries(path, qids, queries):
    """Write one ``qid<TAB>query`` line per qid in *qids*, in the given order."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f"{qid}\t{queries[qid]}\n" for qid in qids)


def main(args):
    """Split ``args.input`` into two disjoint query files, ``<input>.a`` and ``<input>.b``.

    ``args.holdout`` queries are sampled with a fixed seed, so reruns on the same
    input produce the same split. Split ``b`` is always the smaller one: if the
    holdout exceeds half of the queries, the two sizes are swapped. Both output
    files keep the queries in their original input order.

    Args:
        args: namespace with ``input`` (path of the query TSV file) and
            ``holdout`` (int, requested size of the held-out split).

    Raises:
        ValueError: on a malformed input line, a duplicate qid, or a holdout
            that would leave either split empty.
        FileExistsError: if either output file already exists. Both paths are
            checked before anything is written.
    """
    print_message(f"#> Loading queries from {args.input}..")
    queries = _load_queries(args.input)

    # Split "b" is always the smaller split, even when --holdout exceeds half.
    size_a = len(queries) - args.holdout
    size_b = args.holdout
    size_a, size_b = max(size_a, size_b), min(size_a, size_b)
    # Explicit check rather than `assert`, which `python -O` strips.
    if size_a <= 0 or size_b <= 0:
        raise ValueError(
            f"cannot split {len(queries)} queries with holdout={args.holdout} "
            f"into two non-empty sets (sizes {size_a} and {size_b})"
        )

    print_message(f"#> Deterministically splitting the queries into ({size_a}, {size_b})-sized splits.")

    # A private RNG with the same seed yields the same sample as seeding the
    # module-level RNG, without mutating process-wide `random` state.
    rng = random.Random(12345)
    keys = list(queries)
    b_indices = sorted(rng.sample(range(len(keys)), size_b))
    b_index_set = set(b_indices)
    sample_a = [qid for idx, qid in enumerate(keys) if idx not in b_index_set]
    sample_b = [keys[idx] for idx in b_indices]

    output_path_a = f'{args.input}.a'
    output_path_b = f'{args.input}.b'
    # Check both paths up front so a clash never leaves a half-written pair.
    for output_path in (output_path_a, output_path_b):
        if os.path.exists(output_path):
            raise FileExistsError(f"refusing to overwrite existing file: {output_path}")

    print_message(f"#> Writing the splits out to {output_path_a} and {output_path_b} ...")
    _write_queries(output_path_a, sample_a, queries)
    _write_queries(output_path_b, sample_b, queries)
if __name__ == "__main__":
parser = ArgumentParser(description="queries_split.")
# Input Arguments.
parser.add_argument('--input', dest='input', required=True)
parser.add_argument('--holdout', dest='holdout', required=True, type=int)
args = parser.parse_args()
main(args)