import os import ujson import random from argparse import ArgumentParser from colbert.utils.utils import print_message, create_directory, load_ranking, groupby_first_item from utility.utils.qa_loaders import load_qas_ def main(args): print_message("#> Loading all..") qas = load_qas_(args.qas) rankings = load_ranking(args.ranking) qid2rankings = groupby_first_item(rankings) print_message("#> Subsampling all..") qas_sample = random.sample(qas, args.sample) with open(args.output, 'w') as f: for qid, *_ in qas_sample: for items in qid2rankings[qid]: items = [qid] + items line = '\t'.join(map(str, items)) + '\n' f.write(line) print('\n\n') print(args.output) print("#> Done.") if __name__ == "__main__": random.seed(12345) parser = ArgumentParser(description='Subsample the dev set.') parser.add_argument('--qas', dest='qas', required=True, type=str) parser.add_argument('--ranking', dest='ranking', required=True) parser.add_argument('--output', dest='output', required=True) parser.add_argument('--sample', dest='sample', default=1500, type=int) args = parser.parse_args() assert not os.path.exists(args.output), args.output create_directory(os.path.dirname(args.output)) main(args)