# coding:utf-8
"""
Filename: mt5.py
Author: @DvdNss
Created on 12/30/2021
"""
from typing import List

from pytorch_lightning import LightningModule
from transformers import MT5ForConditionalGeneration, AutoTokenizer


class MT5(LightningModule):
    """
    Google MT5 transformer class.
    """

    def __init__(self, model_name_or_path: str = None):
        """
        Initialize module.

        :param model_name_or_path: model name or path
        """
        super().__init__()

        # Load model and tokenizer (left unset when no name or path is given)
        self.save_hyperparameters()
        self.model = MT5ForConditionalGeneration.from_pretrained(
            model_name_or_path) if model_name_or_path is not None else None
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, use_fast=True) if model_name_or_path is not None else None

    def forward(self, **inputs):
        """
        Forward inputs through the underlying model.

        :param inputs: dictionary of inputs (input_ids, attention_mask, labels)
        :return: model outputs
        """
        return self.model(**inputs)

    def qa(self, batch: List[dict], max_length: int = 512, **kwargs):
        """
        Question answering prediction.

        :param batch: batch of dicts {'question': ..., 'context': ...}
        :param max_length: max length of the generated outputs
        :return: list of predicted answers
        """
        # Build the 'question: ... context: ...' inputs expected by the model
        inputs = [f"question: {example['question']} context: {example['context']}" for example in batch]

        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)

        return outputs

    def qg(self, batch: List[str], max_length: int = 512, **kwargs):
        """
        Question generation prediction.

        :param batch: batch of contexts with highlighted elements
        :param max_length: max length of the generated outputs
        :return: list of generated questions
        """
        # Build the 'generate: ...' inputs expected by the model
        inputs = [f"generate: {context}" for context in batch]

        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)

        return outputs

    def ae(self, batch: List[str], max_length: int = 512, **kwargs):
        """
        Answer extraction prediction.

        :param batch: list of contexts
        :param max_length: max length of the generated outputs
        :return: list of extracted answers, with spans separated by <sep>
        """
        # Build the 'extract: ...' inputs expected by the model
        inputs = [f"extract: {context}" for context in batch]

        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)

        return outputs

    def multitask(self, batch: List[str], max_length: int = 512, **kwargs):
        """
        Answer extraction + question generation + question answering.

        :param batch: list of contexts
        :param max_length: max length of the generated outputs
        :return: dict with contexts, extracted answers, generated questions and predicted answers
        """
        # Build output dict
        dict_batch = {'context': list(batch), 'answers': [], 'questions': [], 'answers_bis': []}

        # Iterate over contexts
        for context in batch:
            # Extract answer spans, which the model returns separated by <sep>
            answers = self.ae(batch=[context], max_length=max_length, **kwargs)[0]
            answers = [ans.strip() for ans in answers.split('<sep>') if ans.strip()]
            dict_batch['answers'].append(answers)

            # Highlight each answer in the context and generate one question per answer
            for_qg = [context.replace(ans, f'<hl> {ans} <hl> ') for ans in answers]
            questions = self.qg(batch=for_qg, max_length=max_length, **kwargs)
            dict_batch['questions'].append(questions)

            # Answer the generated questions to cross-check the extracted answers
            new_answers = self.qa([{'context': context, 'question': question} for question in questions],
                                  max_length=max_length, **kwargs)
            dict_batch['answers_bis'].append(new_answers)

        return dict_batch

    def predict(self, inputs, max_length, **kwargs):
        """
        Inference processing.

        :param inputs: list of input strings
        :param max_length: max length used for both tokenization and generation
        :return: list of decoded predictions
        """
        # Tokenize inputs
        inputs = self.tokenizer(inputs, max_length=max_length, padding='max_length', truncation=True,
                                return_tensors="pt")

        # Retrieve input_ids and attention_mask, moved to the model's device
        input_ids = inputs.input_ids.to(self.model.device)
        attention_mask = inputs.attention_mask.to(self.model.device)

        # Predict
        outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=max_length,
                                      **kwargs)

        # Decode outputs, dropping special tokens such as <pad> and </s>
        predictions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        return predictions
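

if __name__ == '__main__':
    # Minimal usage sketch. The checkpoint path below is a placeholder
    # (an assumption, not shipped with this module): substitute any MT5
    # checkpoint fine-tuned on the 'question: ... context: ...',
    # 'generate: ...' and 'extract: ...' prefixes this class expects.
    model = MT5(model_name_or_path='path/to/multitask-mt5-checkpoint')

    # Question answering on a single (question, context) pair
    print(model.qa([{'question': 'What is the capital of France?',
                     'context': 'Paris is the capital of France.'}]))

    # Question generation from a context with a <hl>-highlighted answer span
    print(model.qg(['<hl> Paris <hl> is the capital of France.']))

    # Full pipeline: extract answers, generate questions, re-answer them
    print(model.multitask(['Paris is the capital of France.']))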