# coding:utf-8
"""
Filename: mt5.py
Author: @DvdNss

Created on 12/30/2021
"""

from typing import List

from pytorch_lightning import LightningModule
from transformers import MT5ForConditionalGeneration, AutoTokenizer


class MT5(LightningModule):
    """
    Google MT5 transformer class.
    """

    def __init__(self, model_name_or_path: str = None):
        """
        Initialize module.

        :param model_name_or_path: model name or path
        """

        super().__init__()

        # Load model and tokenizer
        self.save_hyperparameters()
        self.model = MT5ForConditionalGeneration.from_pretrained(
            model_name_or_path) if model_name_or_path is not None else None
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                                       use_fast=True) if model_name_or_path is not None else None

    def forward(self, **inputs):
        """
        Forward inputs.

        :param inputs: dictionary of inputs (input_ids, attention_mask, labels)
        """

        return self.model(**inputs)

    def qa(self, batch: List[dict], max_length: int = 512, **kwargs):
        """
        Question answering prediction.

        :param batch: batch of dicts {question: q, context: c}
        :param max_length: max length of output
        """

        # Transform inputs
        inputs = [f"question: {context['question']} context: {context['context']}" for context in batch]

        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)

        return outputs

    def qg(self, batch: List[str] = None, max_length: int = 512, **kwargs):
        """
        Question generation prediction.

        :param batch: batch of contexts with highlighted elements
        :param max_length: max length of output
        """

        # Transform inputs
        inputs = [f"generate: {context}" for context in batch]

        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)

        return outputs

    def ae(self, batch: List[str], max_length: int = 512, **kwargs):
        """
        Answer extraction prediction.

        :param batch: list of contexts
        :param max_length: max length of output
        """

        # Transform inputs
        inputs = [f"extract: {context}" for context in batch]

        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)

        return outputs

    def multitask(self, batch: List[str], max_length: int = 512, **kwargs):
        """
        Answer extraction + question generation + question answering.

        :param batch: list of contexts
        :param max_length: max length of outputs
        """

        # Build output dict
        dict_batch = {'context': list(batch), 'answers': [], 'questions': [], 'answers_bis': []}

        # Iterate over contexts
        for context in batch:
            # Extract answers; '<sep>' is assumed to be the separator token the
            # answer-extraction head emits between answers
            answers = self.ae(batch=[context], max_length=max_length, **kwargs)[0]
            answers = answers.split('<sep>')
            answers = [ans.strip() for ans in answers if ans.strip() != '']
            dict_batch['answers'].append(answers)

            # Highlight each answer ('<hl>' is assumed to be the highlight token
            # expected by the question-generation head), then generate questions
            for_qg = [context.replace(ans, f"<hl> {ans} <hl>") for ans in answers]
            questions = self.qg(batch=for_qg, max_length=max_length, **kwargs)
            dict_batch['questions'].append(questions)

            # Answer the generated questions to cross-check the extracted answers
            new_answers = self.qa([{'context': context, 'question': question} for question in questions],
                                  max_length=max_length, **kwargs)
            dict_batch['answers_bis'].append(new_answers)

        return dict_batch
    def predict(self, inputs, max_length, **kwargs):
        """
        Inference processing.

        :param inputs: list of inputs
        :param max_length: max length of outputs
        """

        # Tokenize inputs
        inputs = self.tokenizer(inputs, max_length=max_length, padding='max_length', truncation=True,
                                return_tensors="pt")

        # Retrieve input_ids and attention_mask, moving them to the model's device
        input_ids = inputs.input_ids.to(self.model.device)
        attention_mask = inputs.attention_mask.to(self.model.device)

        # Predict
        outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=max_length,
                                      **kwargs)

        # Decode outputs
        predictions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        return predictions
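

# --------------------------------------------------------------------------- #
# Minimal usage sketch, assuming a checkpoint fine-tuned on the task prefixes
# above ('question:', 'generate:', 'extract:'). 'path/to/mt5-checkpoint' is a
# placeholder, not a real model name; outputs are only meaningful with a model
# trained for these tasks.
# --------------------------------------------------------------------------- #
if __name__ == '__main__':
    model = MT5(model_name_or_path='path/to/mt5-checkpoint')  # placeholder path
    model.eval()  # disable dropout for inference

    context = 'Python was created by Guido van Rossum and first released in 1991.'

    # Answer extraction
    print(model.ae(batch=[context]))

    # Question generation on a context with a highlighted answer
    highlighted = context.replace('Guido van Rossum', '<hl> Guido van Rossum <hl>')
    print(model.qg(batch=[highlighted]))

    # Question answering
    print(model.qa(batch=[{'question': 'Who created Python?', 'context': context}]))

    # Full pipeline: extraction -> generation -> answering
    print(model.multitask(batch=[context]))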