Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace NLP Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# Lint as: python3 | |
"""SQUAD: The Stanford Question Answering Dataset.""" | |
from __future__ import absolute_import, division, print_function | |
import json | |
import logging | |
import os | |
import nltk | |
nltk.download('punkt') | |
import nlp | |
# BibTeX entry for the SQuAD paper; handed to the dataset metadata as `citation`.
_CITATION = """\
@article{2016arXiv160605250R,
author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},
Konstantin and {Liang}, Percy},
title = "{SQuAD: 100,000+ Questions for Machine Comprehension of Text}",
journal = {arXiv e-prints},
year = 2016,
eid = {arXiv:1606.05250},
pages = {arXiv:1606.05250},
archivePrefix = {arXiv},
eprint = {1606.05250},
}
"""

# Human-readable summary of the dataset: a single line of text terminated by
# one newline. Adjacent literals are concatenated by the parser.
_DESCRIPTION = (
    "Stanford Question Answering Dataset (SQuAD) is a reading comprehension "
    "dataset, consisting of questions posed by crowdworkers on a set of Wikipedia "
    "articles, where the answer to every question is a segment of text, or span, "
    "from the corresponding reading passage, or the question might be unanswerable.\n"
)

# Input layouts available for the question-generation ("qg") task. One builder
# configuration is created per entry.
QG_FORMATS = ["prepend", "highlight", "prepend_highlight"]
class SquadMultitaskConfig(nlp.BuilderConfig):
    """Builder configuration for the multitask SQuAD dataset."""

    def __init__(self, qg_format="highlight", **kwargs):
        """Store the question-generation format and delegate the rest.

        Args:
            qg_format: layout used for question-generation inputs; kept on
                the instance as ``qg_format``. Defaults to ``"highlight"``.
            **kwargs: keyword arguments forwarded to super.
        """
        super().__init__(**kwargs)
        self.qg_format = qg_format
class SquadMultitask(nlp.GeneratorBasedBuilder):
    """SQUAD: The Stanford Question Answering Dataset. Version 1.1.

    Turns every SQuAD paragraph into text-to-text examples for four tasks.
    Each example is a dict with the string fields ``source_text``,
    ``target_text`` and ``task``, where ``task`` is one of ``qa``, ``qg``,
    ``ans_ext`` or ``e2e_qg``.

    The markers ``{hl_token}`` and ``{sep_token}`` are written literally into
    the generated text. NOTE(review): presumably a later preprocessing step
    replaces them with real special tokens -- confirm against the consumer.
    """

    _URL = "https://rajpurkar.github.io/SQuAD-explorer/dataset/"
    _DEV_FILE = "dev-v1.1.json"
    _TRAINING_FILE = "train-v1.1.json"

    # One configuration per question-generation input format.
    BUILDER_CONFIGS = [
        SquadMultitaskConfig(
            name=f"{format_}_qg_format",
            version=nlp.Version("1.0.0", "New split API (https://tensorflow.org/datasets/splits)"),
            description="Plain text",
            qg_format=format_,
        )
        for format_ in QG_FORMATS
    ]

    def _info(self):
        """Return the dataset metadata: three string features per example."""
        return nlp.DatasetInfo(
            description=_DESCRIPTION,
            features=nlp.Features(
                {
                    "source_text": nlp.Value("string"),
                    "target_text": nlp.Value("string"),
                    "task": nlp.Value("string"),
                }
            ),
            # No default supervised_keys (as we have to pass both question
            # and context as input).
            supervised_keys=None,
            homepage="https://rajpurkar.github.io/SQuAD-explorer/",
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Download the train and dev files and declare the two splits.

        Args:
            dl_manager: download manager; its ``download_and_extract`` is
                called with a ``{"train": url, "dev": url}`` mapping.

        Returns:
            A list with the TRAIN and VALIDATION split generators, each
            carrying the local ``filepath`` of its JSON file.
        """
        urls_to_download = {
            "train": os.path.join(self._URL, self._TRAINING_FILE),
            "dev": os.path.join(self._URL, self._DEV_FILE),
        }
        downloaded_files = dl_manager.download_and_extract(urls_to_download)

        return [
            nlp.SplitGenerator(name=nlp.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}),
            nlp.SplitGenerator(name=nlp.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["dev"]}),
        ]

    def _get_correct_alignement(self, context, answer):
        """Return the ``(start, end)`` character span of the answer in ``context``.

        Some original examples in SQuAD have indices wrong by 1 or 2
        characters, so the annotated start is tried first, then the positions
        one and two characters to its left.

        Args:
            context: text in which the answer is looked up.
            answer: dict with the keys ``text`` and ``answer_start``.

        Returns:
            ``(start_idx, end_idx)`` with
            ``context[start_idx:end_idx] == answer['text']``.

        Raises:
            ValueError: if the answer text is not found at any tried offset.
        """
        gold_text = answer['text']
        annotated_start = answer['answer_start']

        for shift in (0, 1, 2):
            start_idx = annotated_start - shift
            if start_idx < 0:
                # A negative start would make the slice count from the end of
                # the context instead of testing the intended position.
                break
            end_idx = start_idx + len(gold_text)
            if context[start_idx:end_idx] == gold_text:
                return start_idx, end_idx

        raise ValueError(
            f"answer {gold_text!r} not found at offset {annotated_start} "
            "or up to 2 characters before it"
        )

    def process_qa_text(self, context, question, answer):
        """Build a question-answering example.

        Args:
            context: paragraph text.
            question: question text.
            answer: answer text used as the target.

        Returns:
            Example dict whose ``task`` is ``"qa"``.
        """
        ans_gen_input = f"question: {question} context: {context}"
        ans_gen_target = f"{answer}"
        return {"source_text": ans_gen_input, "target_text": ans_gen_target, "task": "qa"}

    def process_qg_text(self, context, question, answer):
        """Build a question-generation example in the configured format.

        Args:
            context: paragraph text.
            question: question text used as the target.
            answer: dict with the keys ``text`` and ``answer_start``.

        Returns:
            Example dict whose ``task`` is ``"qg"``.

        Raises:
            ValueError: for the highlighting formats, when the answer cannot
                be aligned with ``context``.
        """
        answer_text = answer['text'].strip()

        if self.config.qg_format == "prepend":
            que_gen_input = f"answer: {answer_text} context: {context}"
        elif self.config.qg_format == "highlight":
            start_pos, end_pos = self._get_correct_alignement(context, answer)
            que_gen_input = f"generate question: {context[:start_pos]} {{hl_token}} {answer_text} {{hl_token}} {context[end_pos:]}"
        else:
            # Any other value (notably "prepend_highlight") gets both the
            # prepended answer and the highlighted context.
            start_pos, end_pos = self._get_correct_alignement(context, answer)
            que_gen_input = f"answer: {answer_text} context: {context[:start_pos]} {{hl_token}} {answer_text} {{hl_token}} {context[end_pos:]}"

        que_gen_target = f"{question}"
        return {"source_text": que_gen_input, "target_text": que_gen_target, "task": "qg"}

    def process_e2e_qg(self, paragraph):
        """Build one example that maps a paragraph to all of its questions.

        Args:
            paragraph: dict with ``context`` and a ``qas`` list whose items
                have a ``question`` key.

        Returns:
            Example dict whose ``task`` is ``"e2e_qg"``; the questions are
            joined by, and terminated with, the literal ``{sep_token}``.
        """
        source_text = f"generate questions: {paragraph['context'].strip()}"

        questions = [qas['question'].strip() for qas in paragraph['qas']]
        target_text = " {sep_token} ".join(questions)
        target_text = f"{target_text} {{sep_token}}"
        return {"source_text": source_text, "target_text": target_text, "task": "e2e_qg"}

    def process_ans_ext(self, paragraph):
        """Build answer-extraction examples, one per sentence that holds answers.

        The paragraph is split into sentences. For every sentence containing
        the start of at least one answer, the input is the whole paragraph
        with that sentence wrapped in ``{hl_token}`` markers and the target is
        the list of its distinct answers, each followed by ``{sep_token}``.

        Args:
            paragraph: dict with ``context`` and a ``qas`` list whose items
                have an ``answers`` list of ``{"text", "answer_start"}`` dicts.

        Returns:
            List of example dicts whose ``task`` is ``"ans_ext"``.
        """
        raw_context = paragraph['context']

        # split into sentences
        sents = nltk.sent_tokenize(raw_context.strip())

        # Locate every sentence in the *raw* context: `answer_start` is read
        # against that text, and sentences may be separated by more than one
        # whitespace character.
        positions = []
        cursor = 0
        for sent in sents:
            start = raw_context.find(sent, cursor)
            if start == -1:
                # Sentence text not found verbatim; fall back to assuming a
                # single separator character after the previous sentence.
                start = cursor + 1 if positions else 0
            end = start + len(sent)
            positions.append({'start': start, 'end': end})
            cursor = end

        # First annotated answer of every question that has one.
        answers = [qa['answers'][0] for qa in paragraph['qas'] if qa['answers']]

        # Answers whose start offset falls inside each sentence.
        sent_answers = []
        for pos in positions:
            span = range(pos['start'], pos['end'])
            sent_answers.append(
                [ans['text'].strip() for ans in answers if ans['answer_start'] in span]
            )

        # build inputs and targets
        examples = []
        for i, found in enumerate(sent_answers):
            if not found:
                continue

            # Drop duplicates but keep first-seen order so that the generated
            # targets are identical from one run to the next.
            unique_answers = list(dict.fromkeys(found))

            input_text = "extract answers:"
            for j, sent in enumerate(sents):
                if i == j:
                    sent = "{hl_token} %s {hl_token}" % sent
                input_text = "%s %s" % (input_text, sent)
                input_text = input_text.strip()

            target_text = " {sep_token} ".join(unique_answers) + " {sep_token}"
            examples.append({'source_text': input_text, "target_text": target_text, "task": "ans_ext"})

        return examples

    def _generate_examples(self, filepath):
        """Yield ``(index, example)`` pairs for every task, in raw text form.

        For each paragraph the order is: answer-extraction examples, the
        end-to-end question-generation example, then a ``qa`` and a ``qg``
        example per question. Questions without any answer are skipped for
        ``qa`` and ``qg``.

        Args:
            filepath: path of a SQuAD JSON file.

        Yields:
            Tuples of a running integer index and an example dict.
        """
        logging.info("generating examples from = %s", filepath)
        count = 0
        tasks = ['qa', 'qg', 'ans_ext', 'e2e_qg']

        # Read explicitly as UTF-8 and close the file before generating, so
        # the handle is not held open while the generator is suspended.
        with open(filepath, encoding="utf-8") as f:
            squad = json.load(f)

        for article in squad["data"]:
            for paragraph in article["paragraphs"]:
                context = paragraph["context"].strip()

                if 'ans_ext' in tasks:
                    for example in self.process_ans_ext(paragraph):
                        yield count, example
                        count += 1

                if 'e2e_qg' in tasks:
                    yield count, self.process_e2e_qg(paragraph)
                    count += 1

                for qa in paragraph["qas"]:
                    if not qa["answers"]:
                        # Nothing to use as a target or to highlight.
                        continue

                    question = qa["question"].strip()
                    first_answer = qa["answers"][0]

                    for task in tasks:
                        if task == 'qa':
                            yield count, self.process_qa_text(context, question, first_answer["text"].strip())
                            count += 1

                        if task == 'qg':
                            yield count, self.process_qg_text(context, question, first_answer)
                            count += 1