"""Evaluate a Retro Reader model on the SQuAD v2.0 validation set."""
import os

# Disable TensorFlow's oneDNN custom ops; this must be set before transformers
# (and any TensorFlow backend it pulls in) is imported.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

import argparse
from typing import Any, Dict, Union

import datasets
import torch
from huggingface_hub import login
from transformers.utils import check_min_version, logging
from transformers.utils.versions import require_version

from retro_reader import RetroReader
from retro_reader.constants import EXAMPLE_FEATURES

check_min_version("4.13.0.dev0")
require_version("datasets>=1.8.0")

logger = logging.get_logger(__name__)


def schema_integrate(example) -> Union[Dict, Any]:
    """Convert a batch of SQuAD v2.0 examples into the schema used by this library."""
    title = example["title"]
    question = example["question"]
    context = example["context"]
    guid = example["id"]
    classtype = [""] * len(title)
    dataset_name = source = ["squad_v2"] * len(title)
    answers, is_impossible = [], []
    for answer_examples in example["answers"]:
        if answer_examples["text"]:
            answers.append(answer_examples)
            is_impossible.append(False)
        else:
            # Unanswerable question: store an empty answer with a sentinel start index.
            answers.append({"text": [""], "answer_start": [-1]})
            is_impossible.append(True)

    return {
        "guid": guid,
        "question": question,
        "context": context,
        "answers": answers,
        "title": title,
        "classtype": classtype,
        "source": source,
        "is_impossible": is_impossible,
        "dataset": dataset_name,
    }
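

# For reference, a raw SQuAD v2.0 example has the form
#   {"id": ..., "title": ..., "context": ..., "question": ...,
#    "answers": {"text": [...], "answer_start": [...]}}
# schema_integrate re-keys these fields and adds classtype, source, dataset, and
# is_impossible so that the mapped batch matches EXAMPLE_FEATURES.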


def data_aug_for_multiple_answers(examples) -> Union[Dict, Any]:
    """Expand batched examples with multiple gold answers into one row per answer."""
    result = {key: [] for key in examples.keys()}

    def update(i, answers=None):
        # Copy the i-th example, optionally overriding its "answers" field.
        for key in result.keys():
            if key == "answers" and answers is not None:
                result[key].append(answers)
            else:
                result[key].append(examples[key][i])

    for i, (answers, unanswerable) in enumerate(
        zip(examples["answers"], examples["is_impossible"])
    ):
        answerable = not unanswerable
        assert (
            len(answers["text"]) == len(answers["answer_start"])
            or answers["answer_start"][0] == -1
        )
        if answerable and len(answers["text"]) > 1:
            # One output row per annotated answer.
            for n_ans in range(len(answers["text"])):
                ans = {
                    "text": [answers["text"][n_ans]],
                    "answer_start": [answers["answer_start"][n_ans]],
                }
                update(i, ans)
        elif not answerable:
            update(i, {"text": [], "answer_start": []})
        else:
            update(i)

    return result
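

# Note: data_aug_for_multiple_answers is not called in main() below; it is an
# optional preprocessing helper that expands examples carrying several annotated
# answers into one example per answer.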


def main(args):
    print("Loading SQuAD v2.0 dataset ...")
    squad_v2 = datasets.load_dataset("squad_v2")

    if args.debug:
        # Use only a handful of validation examples for a quick smoke test.
        squad_v2["validation"] = squad_v2["validation"].select(range(5))

    print("Integrating into the schema used in this library ...")
    squad_v2 = squad_v2.map(
        schema_integrate,
        batched=True,
        remove_columns=squad_v2.column_names["train"],
        features=EXAMPLE_FEATURES,
    )

    print("Loading Retro Reader ...")
    retro_reader = RetroReader.load(
        config_file=args.configs,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    res = retro_reader.evaluate(squad_v2["validation"])
    print(res)
    logger.warning("Evaluation of the retrospective reader is done.")
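

# Example invocation (the script file name is assumed, not part of this code):
#   python evaluate_squad_v2.py --configs configs/inference_electra_base.yaml --debug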


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--configs", "-c",
        type=str,
        default="configs/inference_electra_base.yaml",
        help="config file path",
    )
    parser.add_argument("--batch_size", "-b", type=int, default=1024, help="batch size")
    parser.add_argument("--debug", "-d", action="store_true", help="debug mode")
    args = parser.parse_args()
    main(args)