#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import contextlib
import json
import os
import re


# Number of answer options read for every question (labels are "A".."D").
NUM_OPTIONS = 4

# Compiled once: collapses any run of whitespace (including newlines).
_WHITESPACE_RE = re.compile(r"\s+")


class InputExample:
    """One multiple-choice question together with its article.

    Attributes are plain public fields:
      paragraph -- the article text, whitespace-normalized to a single line.
      qa_list   -- one question+option string per answer option.
      label     -- zero-based index of the correct option (0 for "A").
    """

    def __init__(self, paragraph, qa_list, label):
        self.paragraph = paragraph
        self.qa_list = qa_list
        self.label = label


def _merge_question_option(question, option):
    """Combine a question with one candidate option into a single line.

    If the question contains "_" (a blank), every "_" is replaced by the
    option; otherwise the option is appended after a space. Whitespace runs
    are collapsed so the result never contains a newline.
    """
    if "_" in question:
        merged = question.replace("_", option)
    else:
        merged = " ".join([question, option])
    return _WHITESPACE_RE.sub(" ", merged)


def get_examples(data_dir, set_type):
    """
    Extract paragraph and question-answer list from each json file

    Args:
        data_dir: root directory laid out as ``<split>/<level>/<file>``.
        set_type: either a split name (e.g. ``"train"``), which reads both
            the "middle" and "high" levels, or ``"<split>-<level>"``
            (e.g. ``"test-high"``), which reads only that level.

    Returns:
        A list of InputExample, one per question. Files inside a level are
        visited in sorted filename order so the output is reproducible.

    Raises:
        OSError: if a level directory or file cannot be read.
        KeyError / IndexError: if a file lacks the expected JSON fields
            ("answers", "options", "questions", "article") or has fewer
            than NUM_OPTIONS options for a question.
    """
    levels = ["middle", "high"]
    parts = set_type.split("-")
    if len(parts) == 2:
        set_type, only_level = parts
        levels = [only_level]

    examples = []
    for level in levels:
        cur_dir = os.path.join(data_dir, set_type, level)
        # os.listdir returns entries in arbitrary order; sort for determinism.
        for filename in sorted(os.listdir(cur_dir)):
            cur_path = os.path.join(cur_dir, filename)
            with open(cur_path, "r", encoding="utf-8") as f:
                cur_data = json.load(f)

            answers = cur_data["answers"]
            options = cur_data["options"]
            questions = cur_data["questions"]
            # "\s+" already matches newlines, so one substitution flattens
            # the article onto a single line.
            context = _WHITESPACE_RE.sub(" ", cur_data["article"])

            for i, answer in enumerate(answers):
                label = ord(answer) - ord("A")
                question = questions[i]
                qa_list = [
                    _merge_question_option(question, options[i][j])
                    for j in range(NUM_OPTIONS)
                ]
                examples.append(InputExample(context, qa_list, label))

    return examples


def _write_examples(examples, output_dir, set_type):
    """Write one split to ``<set_type>.input0..inputN`` and ``<set_type>.label``.

    ``.input0`` receives the paragraph, ``.input1`` .. ``.input<NUM_OPTIONS>``
    receive the question+option strings, and ``.label`` the label; every file
    gets exactly one line per example, so line numbers line up across files.
    """
    prefix = os.path.join(output_dir, set_type)
    # ExitStack closes every file that was opened, even if a later open or
    # a write raises part-way through.
    with contextlib.ExitStack() as stack:

        def open_output(suffix):
            return stack.enter_context(
                open(prefix + suffix, "w", encoding="utf-8")
            )

        qa_files = [
            open_output(".input" + str(i + 1)) for i in range(NUM_OPTIONS)
        ]
        outf_context = open_output(".input0")
        outf_label = open_output(".label")

        for example in examples:
            outf_context.write(example.paragraph + "\n")
            for i, qa_file in enumerate(qa_files):
                qa_file.write(example.qa_list[i] + "\n")
            outf_label.write(str(example.label) + "\n")


def main(argv=None):
    """
    Helper script to extract paragraphs questions and answers from RACE datasets.

    Args:
        argv: optional list of command-line arguments; defaults to
            ``sys.argv[1:]`` (argparse behaviour) when None.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir",
        required=True,
        help="input directory for downloaded RACE dataset",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="output directory for extracted data",
    )
    args = parser.parse_args(argv)

    os.makedirs(args.output_dir, exist_ok=True)

    for set_type in ["train", "dev", "test-middle", "test-high"]:
        examples = get_examples(args.input_dir, set_type)
        _write_examples(examples, args.output_dir, set_type)


if __name__ == "__main__":
    main()