import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
from nltk import sent_tokenize  # requires the 'punkt' data: nltk.download('punkt')

# Answer Extraction Handler
class AEHandler:
    def __init__(self, model, tokenizer):
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

    def __call__(self, context):
        return self.inference(self.preprocess(context))
    def preprocess(self, context):
        # Build one input per sentence, wrapping that sentence in <hl> highlight tokens.
        sents = sent_tokenize(context)
        inputs = []
        for i in range(len(sents)):
            source_text = "extract answers:"
            for j, sent in enumerate(sents):
                if i == j:
                    sent = "<hl> %s <hl>" % sent
                source_text = "%s %s" % (source_text, sent)
            source_text = source_text.strip() + " </s>"
            inputs.append(source_text)
        tokenized_inputs = self.tokenizer(
            inputs,
            max_length=512,
            add_special_tokens=True,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        return tokenized_inputs
    def inference(self, inputs):
        with torch.no_grad():
            outs = self.model.generate(
                input_ids=inputs['input_ids'].to(self.device),
                attention_mask=inputs['attention_mask'].to(self.device),
                max_length=32,
            )
        # Keep special tokens so the <sep> answer delimiters survive decoding,
        # strip the padding, then split each decoded string into answer spans.
        dec = [
            self.tokenizer.decode(ids, skip_special_tokens=False).replace('<pad> ', '').strip()
            for ids in outs
        ]
        answers = [item.split('<sep>')[:-1] for item in dec]
        return answers

    def postprocess(self, outputs):
        return outputs
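
# Note (illustrative, not from the original source): for a two-sentence
# context, AEHandler.preprocess() builds one input per sentence, e.g.
#   "extract answers: <hl> Sentence one. <hl> Sentence two. </s>"
#   "extract answers: Sentence one. <hl> Sentence two. <hl> </s>"
# and inference() returns one list of answer spans per sentence, taken from
# the model's <sep>-delimited output for the highlighted sentence.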

# Question Generation Handler
class QGHandler:
    def __init__(self, model, tokenizer):
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

    def __call__(self, answers, context):
        tokenized_inputs = self.preprocess(answers, context)
        return self.inference(tokenized_inputs)
    def preprocess(self, answers, context):
        # Prepare question-generation inputs: highlight each extracted answer
        # span inside its source sentence with <hl> tokens.
        sents = sent_tokenize(context)
        qg_examples = []
        for i, answer in enumerate(answers):
            if len(answer) == 0:
                continue
            for answer_text in answer:
                sent = sents[i]
                sents_copy = sents[:]
                answer_text = answer_text.strip()
                try:
                    ans_start_idx = sent.index(answer_text)
                except ValueError:
                    # Skip answers that cannot be located in their sentence.
                    continue
                sent = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text):]}"
                sents_copy[i] = sent
                source_text = " ".join(sents_copy)
                source_text = f"generate question: {source_text}"
                # T5-style models expect an explicit end-of-sequence token.
                source_text = source_text + " </s>"
                qg_examples.append({"answer": answer_text, "source_text": source_text})
        # Tokenize all question-generation inputs as one padded batch.
        qg_inputs = [example['source_text'] for example in qg_examples]
        tokenized_inputs = self.tokenizer(
            qg_inputs,
            max_length=512,
            add_special_tokens=True,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # Keep the examples so postprocess() can pair questions with their answers.
        self.qg_examples = qg_examples
        return tokenized_inputs
    def inference(self, inputs):
        with torch.no_grad():
            outs = self.model.generate(
                input_ids=inputs['input_ids'].to(self.device),
                attention_mask=inputs['attention_mask'].to(self.device),
                max_length=32,
                num_beams=4,
            )
        questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
        return questions

    def postprocess(self, questions):
        # Pair each generated question with the answer it was conditioned on.
        outputs = [
            {'question': que, 'answer': example['answer']}
            for example, que in zip(self.qg_examples, questions)
        ]
        return outputs

# Question-Answer Generation Pipeline
class Pipeline:
    def __init__(self, q_model=None, q_tokenizer=None, a_model=None, a_tokenizer=None):
        self.q_model = q_model if q_model is not None else "valhalla/t5-small-qg-hl"
        self.q_tokenizer = q_tokenizer if q_tokenizer is not None else "valhalla/t5-small-qg-hl"
        self.a_model = a_model if a_model is not None else "valhalla/t5-small-qa-qg-hl"
        self.a_tokenizer = a_tokenizer if a_tokenizer is not None else "valhalla/t5-small-qa-qg-hl"
        self.answer_extractor = AEHandler(self.a_model, self.a_tokenizer)
        self.question_generator = QGHandler(self.q_model, self.q_tokenizer)

    def __call__(self, context):
        answers = self.answer_extractor(context)
        questions = self.question_generator(answers, context)
        return self.question_generator.postprocess(questions)
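
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal end-to-end example. Assumes nltk's 'punkt' sentence tokenizer is
# available (nltk.download('punkt')) and that the default valhalla checkpoints
# can be fetched from the Hugging Face Hub; the context text is hypothetical.
if __name__ == "__main__":
    qa_pipeline = Pipeline()
    context = (
        "Python is a high-level programming language. "
        "It was created by Guido van Rossum and first released in 1991."
    )
    for pair in qa_pipeline(context):
        print(pair['question'], '->', pair['answer'])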