import itertools
import logging
from typing import Optional, Dict, Union
from nltk import sent_tokenize
import torch
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
)
logger = logging.getLogger(__name__)
class QGPipeline:
"""Poor man's QG pipeline"""
def __init__(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
ans_model: PreTrainedModel,
ans_tokenizer: PreTrainedTokenizer,
qg_format: str,
use_cuda: bool
):
self.model = model
self.tokenizer = tokenizer
self.ans_model = ans_model
self.ans_tokenizer = ans_tokenizer
self.qg_format = qg_format
self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
self.model.to(self.device)
if self.ans_model is not self.model:
self.ans_model.to(self.device)
assert self.model.__class__.__name__ in ["T5ForConditionalGeneration", "BartForConditionalGeneration"]
if "T5ForConditionalGeneration" in self.model.__class__.__name__:
self.model_type = "t5"
else:
self.model_type = "bart"
def __call__(self, inputs: str):
inputs = " ".join(inputs.split())
sents, answers = self._extract_answers(inputs)
flat_answers = list(itertools.chain(*answers))
if len(flat_answers) == 0:
return []
if self.qg_format == "prepend":
qg_examples = self._prepare_inputs_for_qg_from_answers_prepend(inputs, answers)
else:
qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)
qg_inputs = [example['source_text'] for example in qg_examples]
questions = self._generate_questions(qg_inputs)
output = [{'answer': example['answer'], 'question': que} for example, que in zip(qg_examples, questions)]
return output
def _generate_questions(self, inputs):
inputs = self._tokenize(inputs, padding=True, truncation=True)
outs = self.model.generate(
input_ids=inputs['input_ids'].to(self.device),
attention_mask=inputs['attention_mask'].to(self.device),
max_length=32,
num_beams=4,
)
questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
return questions
def _extract_answers(self, context):
sents, inputs = self._prepare_inputs_for_ans_extraction(context)
inputs = self._tokenize(inputs, padding=True, truncation=True)
outs = self.ans_model.generate(
input_ids=inputs['input_ids'].to(self.device),
attention_mask=inputs['attention_mask'].to(self.device),
max_length=32,
)
        dec = [self.ans_tokenizer.decode(ids, skip_special_tokens=False) for ids in outs]
        answers = [item.split('<sep>') for item in dec]
        # drop the fragment after the last <sep> and strip the special tokens
        # (e.g. <pad>, </s>) that skip_special_tokens=False leaves around the spans
        answers = [
            [a.replace("<pad>", "").replace("</s>", "").strip() for a in item[:-1]]
            for item in answers
        ]
return sents, answers
def _tokenize(self,
inputs,
padding=True,
truncation=True,
add_special_tokens=True,
max_length=512
):
inputs = self.tokenizer.batch_encode_plus(
inputs,
max_length=max_length,
add_special_tokens=add_special_tokens,
truncation=truncation,
padding="max_length" if padding else False,
pad_to_max_length=padding,
return_tensors="pt"
)
return inputs
def _prepare_inputs_for_ans_extraction(self, text):
sents = sent_tokenize(text)
inputs = []
for i in range(len(sents)):
source_text = "extract answers:"
for j, sent in enumerate(sents):
if i == j:
sent = "<hl> %s <hl>" % sent
source_text = "%s %s" % (source_text, sent)
source_text = source_text.strip()
if self.model_type == "t5":
source_text = source_text + " </s>"
inputs.append(source_text)
return sents, inputs
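
    # For an input like "Paris is in France. It is the capital.", this builds one
    # source string per sentence, with that sentence wrapped in <hl> tokens
    # (plus a trailing " </s>" for T5 models):
    #   "extract answers: <hl> Paris is in France. <hl> It is the capital. </s>"
    #   "extract answers: Paris is in France. <hl> It is the capital. <hl> </s>"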
def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
inputs = []
for i, answer in enumerate(answers):
if len(answer) == 0: continue
for answer_text in answer:
sent = sents[i]
sents_copy = sents[:]
                answer_text = answer_text.strip()
                # skip spans the answer model produced that do not occur verbatim in
                # the sentence, otherwise sent.index below would raise ValueError
                if not answer_text or answer_text not in sent:
                    continue
                ans_start_idx = sent.index(answer_text)
                sent = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text): ]}"
sents_copy[i] = sent
source_text = " ".join(sents_copy)
source_text = f"generate question: {source_text}"
if self.model_type == "t5":
source_text = source_text + " </s>"
inputs.append({"answer": answer_text, "source_text": source_text})
return inputs
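
    # Continuing the example above, an extracted answer "France" produces (roughly):
    #   "generate question: Paris is in <hl> France <hl> . It is the capital. </s>"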
def _prepare_inputs_for_qg_from_answers_prepend(self, context, answers):
flat_answers = list(itertools.chain(*answers))
examples = []
for answer in flat_answers:
source_text = f"answer: {answer} context: {context}"
if self.model_type == "t5":
source_text = source_text + " </s>"
examples.append({"answer": answer, "source_text": source_text})
return examples
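
# Example (sketch, not part of the original API): wiring QGPipeline up by hand with
# the default checkpoints listed in SUPPORTED_TASKS below. The pipeline() factory at
# the bottom of this file does the same thing for you.
#
#   model = AutoModelForSeq2SeqLM.from_pretrained("valhalla/t5-small-qg-hl")
#   tokenizer = AutoTokenizer.from_pretrained("valhalla/t5-small-qg-hl")
#   ans_model = AutoModelForSeq2SeqLM.from_pretrained("valhalla/t5-small-qa-qg-hl")
#   ans_tokenizer = AutoTokenizer.from_pretrained("valhalla/t5-small-qa-qg-hl")
#   qg = QGPipeline(model, tokenizer, ans_model, ans_tokenizer,
#                   qg_format="highlight", use_cuda=False)
#   qg("Python was created by Guido van Rossum.")
#   # -> list of {"answer": ..., "question": ...} dicts
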
class MultiTaskQAQGPipeline(QGPipeline):
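    """QA and QG with a single model: a plain string input runs question generation
    (see QGPipeline.__call__), a {"question": ..., "context": ...} dict runs
    question answering."""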
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __call__(self, inputs: Union[Dict, str]):
        if isinstance(inputs, str):
# do qg
return super().__call__(inputs)
else:
# do qa
return self._extract_answer(inputs["question"], inputs["context"])
def _prepare_inputs_for_qa(self, question, context):
source_text = f"question: {question} context: {context}"
if self.model_type == "t5":
source_text = source_text + " </s>"
return source_text
def _extract_answer(self, question, context):
source_text = self._prepare_inputs_for_qa(question, context)
inputs = self._tokenize([source_text], padding=False)
outs = self.model.generate(
input_ids=inputs['input_ids'].to(self.device),
attention_mask=inputs['attention_mask'].to(self.device),
max_length=16,
)
answer = self.tokenizer.decode(outs[0], skip_special_tokens=True)
return answer
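
# Example (sketch): MultiTaskQAQGPipeline dispatches on the input type, as __call__
# above shows. Model outputs are illustrative, not guaranteed.
#
#   nlp = pipeline("multitask-qa-qg")   # default checkpoint: valhalla/t5-small-qa-qg-hl
#   nlp("42 is the answer to life, the universe and everything.")      # question generation
#   nlp({"question": "What is the answer to life, the universe and everything?",
#        "context": "42 is the answer to life, the universe and everything."})  # question answering
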
class E2EQGPipeline:
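    """End-to-end question generation: generates a set of questions for a context in
    a single pass, with no separate answer-extraction step."""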
def __init__(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
use_cuda: bool
    ):
self.model = model
self.tokenizer = tokenizer
self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
self.model.to(self.device)
assert self.model.__class__.__name__ in ["T5ForConditionalGeneration", "BartForConditionalGeneration"]
if "T5ForConditionalGeneration" in self.model.__class__.__name__:
self.model_type = "t5"
else:
self.model_type = "bart"
self.default_generate_kwargs = {
"max_length": 256,
"num_beams": 4,
"length_penalty": 1.5,
"no_repeat_ngram_size": 3,
"early_stopping": True,
}
def __call__(self, context: str, **generate_kwargs):
inputs = self._prepare_inputs_for_e2e_qg(context)
        # TODO: when overriding default_generate_kwargs, all other arguments need to be passed;
# find a better way to do this
if not generate_kwargs:
generate_kwargs = self.default_generate_kwargs
input_length = inputs["input_ids"].shape[-1]
# max_length = generate_kwargs.get("max_length", 256)
# if input_length < max_length:
# logger.warning(
# "Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
# max_length, input_length
# )
# )
outs = self.model.generate(
input_ids=inputs['input_ids'].to(self.device),
attention_mask=inputs['attention_mask'].to(self.device),
**generate_kwargs
)
prediction = self.tokenizer.decode(outs[0], skip_special_tokens=True)
questions = prediction.split("<sep>")
questions = [question.strip() for question in questions[:-1]]
return questions
def _prepare_inputs_for_e2e_qg(self, context):
source_text = f"generate questions: {context}"
if self.model_type == "t5":
source_text = source_text + " </s>"
inputs = self._tokenize([source_text], padding=False)
return inputs
def _tokenize(
self,
inputs,
padding=True,
truncation=True,
add_special_tokens=True,
max_length=512
):
inputs = self.tokenizer.batch_encode_plus(
inputs,
max_length=max_length,
add_special_tokens=add_special_tokens,
truncation=truncation,
padding="max_length" if padding else False,
pad_to_max_length=padding,
return_tensors="pt"
)
return inputs
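
# Example (sketch): E2EQGPipeline returns a list of question strings. Note that any
# generate_kwargs you pass replace default_generate_kwargs wholesale (see the TODO
# in __call__), so include every decoding argument you care about.
#
#   e2e = pipeline("e2e-qg")   # default checkpoint: valhalla/t5-small-e2e-qg
#   e2e("Python is a programming language created by Guido van Rossum.")
#   e2e("Python is a programming language.", max_length=64, num_beams=2, early_stopping=True)
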
SUPPORTED_TASKS = {
"question-generation": {
"impl": QGPipeline,
"default": {
"model": "valhalla/t5-small-qg-hl",
"ans_model": "valhalla/t5-small-qa-qg-hl",
}
},
"multitask-qa-qg": {
"impl": MultiTaskQAQGPipeline,
"default": {
"model": "valhalla/t5-small-qa-qg-hl",
}
},
"e2e-qg": {
"impl": E2EQGPipeline,
"default": {
"model": "valhalla/t5-small-e2e-qg",
}
}
}
def pipeline(
task: str,
    model: Optional[Union[str, PreTrainedModel]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
qg_format: Optional[str] = "highlight",
    ans_model: Optional[Union[str, PreTrainedModel]] = None,
ans_tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
use_cuda: Optional[bool] = True,
**kwargs,
):
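    """Build one of the pipelines registered in SUPPORTED_TASKS: "question-generation",
    "multitask-qa-qg" or "e2e-qg". `model`, `tokenizer`, `ans_model` and `ans_tokenizer`
    may be passed as already-loaded objects or as Hub identifiers/paths; anything left
    as None falls back to the task's defaults.
    """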
# Retrieve the task
if task not in SUPPORTED_TASKS:
raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
targeted_task = SUPPORTED_TASKS[task]
task_class = targeted_task["impl"]
# Use default model/config/tokenizer for the task if no model is provided
if model is None:
model = targeted_task["default"]["model"]
# Try to infer tokenizer from model or config name (if provided as str)
if tokenizer is None:
if isinstance(model, str):
tokenizer = model
else:
            # Impossible to guess which tokenizer to use here
raise Exception(
"Impossible to guess which tokenizer to use. "
"Please provided a PretrainedTokenizer class or a path/identifier to a pretrained tokenizer."
)
# Instantiate tokenizer if needed
if isinstance(tokenizer, (str, tuple)):
if isinstance(tokenizer, tuple):
# For tuple we have (tokenizer name, {kwargs})
tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1])
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
# Instantiate model if needed
if isinstance(model, str):
model = AutoModelForSeq2SeqLM.from_pretrained(model)
if task == "question-generation":
if ans_model is None:
# load default ans model
ans_model = targeted_task["default"]["ans_model"]
ans_tokenizer = AutoTokenizer.from_pretrained(ans_model)
ans_model = AutoModelForSeq2SeqLM.from_pretrained(ans_model)
else:
# Try to infer tokenizer from model or config name (if provided as str)
if ans_tokenizer is None:
if isinstance(ans_model, str):
ans_tokenizer = ans_model
else:
                    # Impossible to guess which tokenizer to use here
raise Exception(
"Impossible to guess which tokenizer to use. "
"Please provided a PretrainedTokenizer class or a path/identifier to a pretrained tokenizer."
)
# Instantiate tokenizer if needed
if isinstance(ans_tokenizer, (str, tuple)):
if isinstance(ans_tokenizer, tuple):
# For tuple we have (tokenizer name, {kwargs})
ans_tokenizer = AutoTokenizer.from_pretrained(ans_tokenizer[0], **ans_tokenizer[1])
else:
ans_tokenizer = AutoTokenizer.from_pretrained(ans_tokenizer)
if isinstance(ans_model, str):
ans_model = AutoModelForSeq2SeqLM.from_pretrained(ans_model)
if task == "e2e-qg":
return task_class(model=model, tokenizer=tokenizer, use_cuda=use_cuda)
elif task == "question-generation":
return task_class(model=model, tokenizer=tokenizer, ans_model=ans_model, ans_tokenizer=ans_tokenizer, qg_format=qg_format, use_cuda=use_cuda)
else:
return task_class(model=model, tokenizer=tokenizer, ans_model=model, ans_tokenizer=tokenizer, qg_format=qg_format, use_cuda=use_cuda)
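

if __name__ == "__main__":
    # Minimal smoke test (sketch, not part of the original module). Running it
    # downloads the default checkpoints from the Hugging Face Hub and assumes the
    # nltk "punkt" data needed by sent_tokenize is installed.
    text = (
        "Python is a programming language. "
        "It was created by Guido van Rossum and first released in 1991."
    )

    qg = pipeline("question-generation", use_cuda=False)
    print(qg(text))  # [{"answer": ..., "question": ...}, ...]

    multi = pipeline("multitask-qa-qg", use_cuda=False)
    print(multi({"question": "Who created Python?", "context": text}))

    e2e = pipeline("e2e-qg", use_cuda=False)
    print(e2e(text))  # ["question 1", "question 2", ...]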