|
from typing import List, Dict, Tuple |
|
from transformers import T5TokenizerFast, T5ForConditionalGeneration |
|
import string |
|
from typing import List |
|
|
|
|
|
# Checkpoint name used to load the tokenizer in QuestionAnswerGenerator.
MODEL_NAME = 't5-small'

# Length (in tokens) the encoded model input is padded / truncated to.
SOURCE_MAX_TOKEN_LEN = 300

# `max_length` handed to generation for the produced sequence.
TARGET_MAX_TOKEN_LEN = 80

# Separator written between the answer and the context in the model input,
# and used to split the decoded output into answer and question.
SEP_TOKEN = '<sep>'

# NOTE(review): not referenced by the visible code; presumably the expected
# tokenizer size once SEP_TOKEN has been added -- confirm before relying on it.
TOKENIZER_LEN = 32101
|
|
|
|
|
class QuestionAnswerGenerator:
    """Generate questions (and answers) for a context with a T5 seq2seq model.

    The model input has the form ``"<answer> <sep> <context>"``. The decoded
    model output is split on ``SEP_TOKEN``; the first part is treated as the
    answer and the second part as the question.
    """

    def __init__(self):
        """Load the tokenizer, register the separator token and load the model."""
        self.tokenizer = T5TokenizerFast.from_pretrained(MODEL_NAME)
        # Register SEP_TOKEN as an added token of the tokenizer.
        self.tokenizer.add_tokens(SEP_TOKEN)
        self.tokenizer_len = len(self.tokenizer)
        self.model = T5ForConditionalGeneration.from_pretrained("fahmiaziz/QAModel")

    def generate(self, answer: str, context: str) -> str:
        """Return a question generated for ``answer`` within ``context``.

        Args:
            answer: Answer text placed before the separator in the model input.
            context: Passage placed after the separator in the model input.

        Returns:
            The question part of the model output. If the output contains no
            separator, the whole output is returned as the question (the same
            fallback ``generate_qna`` uses) instead of raising ``ValueError``.
        """
        model_output = self._model_predict(answer, context)
        _, generated_question = self._split_output(model_output, SEP_TOKEN)
        return generated_question

    def generate_qna(self, context: str) -> Tuple[str, str]:
        """Return an ``(answer, question)`` pair generated from ``context``.

        The answer slot of the model input is filled with the literal
        ``'[MASK]'`` placeholder.

        Args:
            context: Passage placed after the separator in the model input.

        Returns:
            ``(generated_answer, generated_question)``. When the model output
            contains no separator the answer is ``''`` and the whole output is
            returned as the question.
        """
        answer_mask = '[MASK]'
        model_output = self._model_predict(answer_mask, context)
        return self._split_output(model_output, SEP_TOKEN)

    @staticmethod
    def _split_output(model_output: str, separator: str) -> Tuple[str, str]:
        """Split decoded model output into ``(answer, question)`` on ``separator``.

        Parts are returned exactly as produced by ``str.split`` (no stripping).
        Anything after a second separator is ignored.
        """
        parts = model_output.split(separator)
        if len(parts) < 2:
            # No separator present: treat everything as the question.
            return '', parts[0]
        return parts[0], parts[1]

    def _model_predict(self, answer: str, context: str) -> str:
        """Encode ``answer``/``context``, run beam-search generation and decode.

        Args:
            answer: Text placed before the separator in the model input.
            context: Text placed after the separator in the model input.

        Returns:
            The decoded generated sequence(s), concatenated in generation order.
        """
        source_encoding = self.tokenizer(
            f'{answer} {SEP_TOKEN} {context}',
            max_length=SOURCE_MAX_TOKEN_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )

        generated_ids = self.model.generate(
            input_ids=source_encoding['input_ids'],
            attention_mask=source_encoding['attention_mask'],
            num_beams=16,
            max_length=TARGET_MAX_TOKEN_LEN,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
            use_cache=True,
        )

        # A list (not a set) keeps generation order and does not silently drop
        # duplicate sequences, so the joined result is deterministic.
        preds = [
            self.tokenizer.decode(
                generated_id,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            for generated_id in generated_ids
        ]

        return ''.join(preds)
|
|