ir_chinese_medqa / utility /rankings /dev_subsample.py
欧卫
'add_app_files'
58627fa
raw
history blame contribute delete
No virus
1.35 kB
import os
import ujson
import random
from argparse import ArgumentParser
from colbert.utils.utils import print_message, create_directory, load_ranking, groupby_first_item
from utility.utils.qa_loaders import load_qas_
def main(args):
print_message("#> Loading all..")
qas = load_qas_(args.qas)
rankings = load_ranking(args.ranking)
qid2rankings = groupby_first_item(rankings)
print_message("#> Subsampling all..")
qas_sample = random.sample(qas, args.sample)
with open(args.output, 'w') as f:
for qid, *_ in qas_sample:
for items in qid2rankings[qid]:
items = [qid] + items
line = '\t'.join(map(str, items)) + '\n'
f.write(line)
print('\n\n')
print(args.output)
print("#> Done.")
if __name__ == "__main__":
    # Fixed seed so the same inputs always yield the same subsample.
    random.seed(12345)

    parser = ArgumentParser(description='Subsample the dev set.')
    parser.add_argument('--qas', dest='qas', required=True, type=str)
    parser.add_argument('--ranking', dest='ranking', required=True, type=str)
    parser.add_argument('--output', dest='output', required=True, type=str)
    parser.add_argument('--sample', dest='sample', default=1500, type=int)

    args = parser.parse_args()

    # Refuse to overwrite an existing output file. An explicit check (rather
    # than ``assert``) keeps this guard active under ``python -O``.
    if os.path.exists(args.output):
        parser.error(f"output path already exists: {args.output}")

    # A bare filename has an empty dirname; there is no directory to create then.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        create_directory(output_dir)

    main(args)