flynn-chen
all
97ec4dd
# coding=utf-8
# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace NLP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""SQUAD: The Stanford Question Answering Dataset."""
from __future__ import absolute_import, division, print_function
import json
import logging
import os
import nltk
nltk.download('punkt')
import nlp
# BibTeX entry for the SQuAD paper; passed as `citation` to `nlp.DatasetInfo`.
_CITATION = """\
@article{2016arXiv160605250R,
author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},
Konstantin and {Liang}, Percy},
title = "{SQuAD: 100,000+ Questions for Machine Comprehension of Text}",
journal = {arXiv e-prints},
year = 2016,
eid = {arXiv:1606.05250},
pages = {arXiv:1606.05250},
archivePrefix = {arXiv},
eprint = {1606.05250},
}
"""
# Human-readable dataset summary; passed as `description` to `nlp.DatasetInfo`.
_DESCRIPTION = """\
Stanford Question Answering Dataset (SQuAD) is a reading comprehension \
dataset, consisting of questions posed by crowdworkers on a set of Wikipedia \
articles, where the answer to every question is a segment of text, or span, \
from the corresponding reading passage, or the question might be unanswerable.
"""
# Names of the supported question-generation input formats. One builder config
# named "<format>_qg_format" is created for each entry.
QG_FORMATS = [
"prepend",
"highlight",
"prepend_highlight",
]
class SquadMultitaskConfig(nlp.BuilderConfig):
    """Builder configuration for SQUAD carrying the question-generation format."""

    def __init__(self, qg_format="highlight", **kwargs):
        """Initialise the configuration.

        Args:
            qg_format: name of the question-generation input format; stored
                unchanged on ``self.qg_format`` (no validation is performed).
            **kwargs: keyword arguments forwarded to super.
        """
        super().__init__(**kwargs)
        self.qg_format = qg_format
class SquadMultitask(nlp.GeneratorBasedBuilder):
    """SQUAD: The Stanford Question Answering Dataset. Version 1.1.

    Converts the SQuAD v1.1 JSON files into text-to-text examples with the
    fields ``source_text``, ``target_text`` and ``task``. Four kinds of
    examples are produced, distinguished by ``task``:

    * ``qa``      -- question + context -> answer
    * ``qg``      -- (highlighted/prepended) answer + context -> question
    * ``ans_ext`` -- context with one highlighted sentence -> answers in it
    * ``e2e_qg``  -- context -> all questions of the paragraph

    The strings ``{hl_token}`` and ``{sep_token}`` are written literally into
    the examples (presumably replaced by real tokens downstream -- verify
    against the consumer of this dataset).
    """

    _URL = "https://rajpurkar.github.io/SQuAD-explorer/dataset/"
    _DEV_FILE = "dev-v1.1.json"
    _TRAINING_FILE = "train-v1.1.json"

    # One configuration per question-generation format listed in QG_FORMATS.
    BUILDER_CONFIGS = [
        SquadMultitaskConfig(
            name=f"{format_}_qg_format",
            version=nlp.Version("1.0.0", "New split API (https://tensorflow.org/datasets/splits)"),
            description="Plain text",
            qg_format=format_,
        )
        for format_ in QG_FORMATS
    ]

    def _info(self):
        """Return the dataset metadata and the three string features."""
        return nlp.DatasetInfo(
            description=_DESCRIPTION,
            features=nlp.Features(
                {
                    "source_text": nlp.Value("string"),
                    "target_text": nlp.Value("string"),
                    "task": nlp.Value("string"),
                }
            ),
            # No default supervised_keys (as we have to pass both question
            # and context as input).
            supervised_keys=None,
            homepage="https://rajpurkar.github.io/SQuAD-explorer/",
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Download the train/dev files and declare the TRAIN and VALIDATION splits.

        Args:
            dl_manager: object providing ``download_and_extract``; called with
                a ``{"train": url, "dev": url}`` mapping and expected to
                return a mapping with the same keys.

        Returns:
            A list with two ``nlp.SplitGenerator`` objects whose ``gen_kwargs``
            carry the local ``filepath`` of the corresponding file.
        """
        # URLs always use "/" as separator. ``_URL`` already ends with one, so
        # plain concatenation is correct on every platform (os.path.join would
        # insert a backslash on Windows).
        urls_to_download = {
            "train": self._URL + self._TRAINING_FILE,
            "dev": self._URL + self._DEV_FILE,
        }
        downloaded_files = dl_manager.download_and_extract(urls_to_download)
        return [
            nlp.SplitGenerator(name=nlp.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}),
            nlp.SplitGenerator(name=nlp.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["dev"]}),
        ]

    def _get_correct_alignement(self, context, answer):
        """Locate the answer span in ``context``, tolerating small offset errors.

        Some original examples in SQuAD have indices wrong by 1 or 2
        characters. We test and fix this here.

        Args:
            context: the paragraph text the offsets refer to.
            answer: mapping with the keys ``text`` and ``answer_start``.

        Returns:
            ``(start, end)`` such that ``context[start:end] == answer['text']``.

        Raises:
            ValueError: if the text is not found at the given offset nor one
                or two characters before it.
        """
        gold_text = answer["text"]
        start_idx = answer["answer_start"]
        end_idx = start_idx + len(gold_text)
        # Try the annotated position first, then one and two characters earlier.
        for shift in (0, 1, 2):
            start, end = start_idx - shift, end_idx - shift
            if context[start:end] == gold_text:
                return start, end
        raise ValueError(
            f"Could not align answer {gold_text!r} at answer_start={start_idx} with the context"
        )

    def process_qa_text(self, context, question, answer):
        """Build a ``qa`` example: question and context as input, answer text as target."""
        ans_gen_input = f"question: {question} context: {context}"
        ans_gen_target = f"{answer}"
        return {"source_text": ans_gen_input, "target_text": ans_gen_target, "task": "qa"}

    def process_qg_text(self, context, question, answer):
        """Build a ``qg`` example whose input layout depends on ``self.config.qg_format``.

        Args:
            context: paragraph text.
            question: target question.
            answer: mapping with the keys ``text`` and ``answer_start``.

        Returns:
            A dict with ``source_text``, ``target_text`` and ``task == "qg"``.

        Raises:
            ValueError: propagated from ``_get_correct_alignement`` when the
                answer span cannot be located (not raised for "prepend").
        """
        answer_text = answer["text"].strip()
        qg_format = self.config.qg_format
        if qg_format == "prepend":
            que_gen_input = f"answer: {answer_text} context: {context}"
        else:
            # Both remaining layouts wrap the answer span in highlight markers.
            start_pos, end_pos = self._get_correct_alignement(context, answer)
            highlighted = f"{context[:start_pos]} {{hl_token}} {answer_text} {{hl_token}} {context[end_pos:]}"
            if qg_format == "highlight":
                que_gen_input = f"generate question: {highlighted}"
            else:
                # "prepend_highlight" (and any other value) prepends the answer too.
                que_gen_input = f"answer: {answer_text} context: {highlighted}"
        que_gen_target = f"{question}"
        return {"source_text": que_gen_input, "target_text": que_gen_target, "task": "qg"}

    def process_e2e_qg(self, paragraph):
        """Build an ``e2e_qg`` example: context as input, all questions as target.

        Every question, including the last one, is followed by ``{sep_token}``.
        """
        source_text = f"generate questions: {paragraph['context'].strip()}"
        questions = [qas["question"].strip() for qas in paragraph["qas"]]
        target_text = " {sep_token} ".join(questions)
        target_text = f"{target_text} {{sep_token}}"
        return {"source_text": source_text, "target_text": target_text, "task": "e2e_qg"}

    def process_ans_ext(self, paragraph):
        """Build ``ans_ext`` examples, one per sentence that contains an answer.

        The paragraph is split into sentences with ``nltk.sent_tokenize``. For
        each sentence in which at least one (first) answer of a question
        starts, an example is produced whose input is the whole paragraph with
        that sentence wrapped in ``{hl_token}`` markers and whose target lists
        the distinct answers, each followed by ``{sep_token}``.

        Args:
            paragraph: mapping with ``context`` and ``qas``; every entry of
                ``qas`` has an ``answers`` list of mappings with ``text`` and
                ``answer_start``.

        Returns:
            A list of example dicts (possibly empty).
        """
        raw_context = paragraph["context"]
        context = raw_context.strip()
        # ``answer_start`` indexes the raw context; sentences are taken from
        # the stripped one, so offsets must be shifted by the removed prefix.
        lead = len(raw_context) - len(raw_context.lstrip())

        # split into sentences
        sents = nltk.sent_tokenize(context)

        # get positions of the sentences by searching for them in the context,
        # so that runs of whitespace between sentences do not shift offsets
        positions = []
        cursor = 0
        for sent in sents:
            start = context.find(sent, cursor)
            if start == -1:
                # Sentence not found verbatim: fall back to assuming a single
                # separator character after the previous sentence.
                start = cursor + 1 if positions else 0
            end = start + len(sent)
            positions.append((start, end))
            cursor = end

        # get the first answer of every question that has one
        answers = [qa["answers"][0] for qa in paragraph["qas"] if qa["answers"]]

        # get list of distinct answers for each sentence; dict.fromkeys keeps
        # first-seen order so the targets are reproducible across runs
        sent_answers = []
        for start, end in positions:
            in_sentence = [
                ans["text"].strip()
                for ans in answers
                if start <= ans["answer_start"] - lead < end
            ]
            sent_answers.append(list(dict.fromkeys(in_sentence)))

        # build inputs and targets
        examples = []
        for i, ans in enumerate(sent_answers):
            if not ans:
                continue
            input_text = "extract answers:"
            for j, sent in enumerate(sents):
                if i == j:
                    sent = f"{{hl_token}} {sent} {{hl_token}}"
                input_text = f"{input_text} {sent}".strip()
            target_text = " {sep_token} ".join(ans) + " {sep_token}"
            examples.append({"source_text": input_text, "target_text": target_text, "task": "ans_ext"})
        return examples

    def _generate_examples(self, filepath):
        """This function returns the examples in the raw (text) form.

        Yields ``(index, example)`` pairs with a running index. For every
        paragraph the ``ans_ext`` examples come first, then one ``e2e_qg``
        example, then for each question a ``qa`` and a ``qg`` example built
        from the question's first answer.

        Args:
            filepath: path of a SQuAD-format JSON file.
        """
        logging.info("generating examples from = %s", filepath)
        count = 0
        tasks = ("qa", "qg", "ans_ext", "e2e_qg")
        # JSON is read as UTF-8 regardless of the platform's default encoding;
        # the file is fully parsed here, so it need not stay open while yielding.
        with open(filepath, encoding="utf-8") as f:
            squad = json.load(f)
        for article in squad["data"]:
            for paragraph in article["paragraphs"]:
                context = paragraph["context"].strip()
                if "ans_ext" in tasks:
                    for example in self.process_ans_ext(paragraph):
                        yield count, example
                        count += 1
                if "e2e_qg" in tasks:
                    yield count, self.process_e2e_qg(paragraph)
                    count += 1
                for qa in paragraph["qas"]:
                    if not qa["answers"]:
                        # Nothing to build qa/qg examples from.
                        continue
                    question = qa["question"].strip()
                    first_answer = qa["answers"][0]
                    if "qa" in tasks:
                        yield count, self.process_qa_text(context, question, first_answer["text"].strip())
                        count += 1
                    if "qg" in tasks:
                        yield count, self.process_qg_text(context, question, first_answer)
                        count += 1