from typing import Any, Dict, List
import itertools

import torch
from nltk.tokenize import sent_tokenize  # requires the NLTK "punkt" data
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class PreTrainedPipeline:
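    """Question-generation pipeline: extracts candidate answer spans from a
    passage, then writes a question for each span with a seq2seq model."""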

    def __init__(self, path=""):
        # Run on GPU when one is available; tensors are moved to the same
        # device before each generate() call.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # config.model_type is "t5" for T5-style checkpoints; their prompts
        # get an explicit " </s>" terminator below.
        self.model_type = self.model.config.model_type

    def __call__(self, inputs: str):
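        """Return a list of {"answer": ..., "question": ...} dicts for `inputs`."""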
        if len(inputs) == 0:
            return []

        # Collapse runs of whitespace so sentence tokenization is stable.
        inputs = " ".join(inputs.split())
        sents, answers = self._extract_answers(inputs)
        flat_answers = list(itertools.chain(*answers))

        if len(flat_answers) == 0:
            return []

        qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)
        qg_inputs = [example["source_text"] for example in qg_examples]
        questions = self._generate_questions(qg_inputs)

        output = [
            {"answer": example["answer"], "question": que}
            for example, que in zip(qg_examples, questions)
        ]
        return self.clean_generated_QAs(output)

    def _extract_answers(self, context):
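        """Extract candidate answer spans, one list per input sentence."""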
        sents, inputs = self._prepare_inputs_for_ans_extraction(context)
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs["input_ids"].to(self.device),
            attention_mask=inputs["attention_mask"].to(self.device),
            max_length=32,
        )

        # Keep special tokens: the model emits answers joined by "<sep>".
        dec = [self.tokenizer.decode(ids, skip_special_tokens=False) for ids in outs]
        answers = [item.split("<sep>") for item in dec]
        # The chunk after the last "<sep>" is end-of-sequence residue; drop it.
        answers = [i[:-1] for i in answers]

        return sents, answers

    def _prepare_inputs_for_ans_extraction(self, text):
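        """Build one answer-extraction prompt per sentence, highlighting
        that sentence with <hl> markers."""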
        sents = sent_tokenize(text)

        inputs = []
        for i in range(len(sents)):
            source_text = "extract answers:"
            for j, sent in enumerate(sents):
                if i == j:
                    sent = "<hl> %s <hl>" % sent
                source_text = "%s %s" % (source_text, sent)
            source_text = source_text.strip()

            if self.model_type == "t5":
                source_text = source_text + " </s>"
            inputs.append(source_text)

        return sents, inputs

    def _tokenize(self, inputs, padding=True, truncation=True,
                  add_special_tokens=True, max_length=512):
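        """Batch-encode `inputs` into fixed-length PyTorch tensors."""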
        inputs = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            # Pad every sequence to max_length so the batch stacks cleanly.
            padding="max_length" if padding else False,
            return_tensors="pt",
        )
        return inputs

    def _generate_questions(self, inputs):
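        """Generate one question per prompt using beam search."""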
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs["input_ids"].to(self.device),
            attention_mask=inputs["attention_mask"].to(self.device),
            max_length=32,
            num_beams=4,
        )

        questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
        return questions

    def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
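        """Build "generate question:" prompts by highlighting each extracted
        answer inside its source sentence."""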
        inputs = []
        for i, answer in enumerate(answers):
            if len(answer) == 0:
                continue
            for answer_text in answer:
                sent = sents[i]
                sents_copy = sents[:]
                answer_text = self.remove_pad(answer_text).strip()

                # Skip answers that cannot be located in their source sentence
                # (e.g. the model paraphrased rather than copied a span).
                try:
                    ans_start_idx = sent.lower().index(answer_text.lower())
                except ValueError:
                    continue

                sent = (
                    f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> "
                    f"{sent[ans_start_idx + len(answer_text):]}"
                )
                sents_copy[i] = sent

                source_text = " ".join(sents_copy)
                source_text = f"generate question: {source_text}"
                if self.model_type == "t5":
                    source_text = source_text + " </s>"

                inputs.append({"answer": answer_text, "source_text": source_text})

        return inputs

    def clean_generated_QAs(self, generated_QAs):
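        """Deduplicate QA pairs by answer, keeping the first occurrence."""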
        clean_QAs = []
        answers_used = set()

        for qa in generated_QAs:
            # Skip duplicate answers but keep scanning the remaining pairs.
            if qa["answer"] in answers_used:
                continue
            answers_used.add(qa["answer"])
            clean_QAs.append(qa)
        return clean_QAs

    def remove_pad(self, text):
        # Strip the tokenizer's "<pad>" markers left in decoded answers.
        return text.replace("<pad>", "")
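

# Minimal usage sketch (the checkpoint path below is a placeholder, not a
# real model id; any seq2seq checkpoint fine-tuned for <hl>-style answer
# extraction and question generation is assumed).
if __name__ == "__main__":
    pipe = PreTrainedPipeline(path="./qg-checkpoint")  # hypothetical local path
    qa_pairs = pipe("Python was created by Guido van Rossum and first released in 1991.")
    for pair in qa_pairs:
        print(pair["question"], "->", pair["answer"])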