Spaces:
Runtime error
Runtime error
# coding:utf-8
"""
Filename: mt5.py
Author: @DvdNss
Created on 12/30/2021
"""
from typing import List, Optional

from pytorch_lightning import LightningModule
from transformers import MT5ForConditionalGeneration, AutoTokenizer
class MT5(LightningModule):
    """
    Google MT5 transformer class.

    Wraps an ``MT5ForConditionalGeneration`` model together with its tokenizer and exposes
    prompt-prefixed helpers: question answering (``qa``), question generation (``qg``),
    answer extraction (``ae``) and a pipeline chaining the three (``multitask``).
    """

    def __init__(self, model_name_or_path: Optional[str] = None):
        """
        Initialize module.

        :param model_name_or_path: model name or path handed to ``from_pretrained``; when None,
            both ``self.model`` and ``self.tokenizer`` are left as None
        """
        super().__init__()

        self.save_hyperparameters()

        # Load model and tokenizer (only when a checkpoint name/path was given)
        if model_name_or_path is not None:
            self.model = MT5ForConditionalGeneration.from_pretrained(model_name_or_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
        else:
            self.model = None
            self.tokenizer = None

    def forward(self, **inputs):
        """
        Forward inputs.

        :param inputs: dictionary of inputs (input_ids, attention_mask, labels)
        :return: whatever the wrapped model returns for these inputs
        """
        return self.model(**inputs)

    def qa(self, batch: List[dict], max_length: int = 512, **kwargs):
        """
        Question answering prediction.

        :param batch: batch of dict {question: q, context: c}
        :param max_length: max length of output
        :param kwargs: extra keyword arguments forwarded to ``self.model.generate``
        :return: list of decoded predictions, one per item of the batch
        """
        # Transform inputs
        inputs = [f"question: {item['question']} context: {item['context']}" for item in batch]

        # Predict
        return self.predict(inputs=inputs, max_length=max_length, **kwargs)

    def qg(self, batch: Optional[List[str]] = None, max_length: int = 512, **kwargs):
        """
        Question generation prediction.

        :param batch: batch of context with highlighted elements; None is treated as an empty batch
        :param max_length: max length of output
        :param kwargs: extra keyword arguments forwarded to ``self.model.generate``
        :return: list of decoded predictions, one per item of the batch
        """
        # The default is None, which cannot be iterated: fall back to an empty batch
        if batch is None:
            batch = []

        # Transform inputs
        inputs = [f"generate: {context}" for context in batch]

        # Predict
        return self.predict(inputs=inputs, max_length=max_length, **kwargs)

    def ae(self, batch: List[str], max_length: int = 512, **kwargs):
        """
        Answer extraction prediction.

        :param batch: list of context
        :param max_length: max length of output
        :param kwargs: extra keyword arguments forwarded to ``self.model.generate``
        :return: list of decoded predictions, one per item of the batch
        """
        # Transform inputs
        inputs = [f"extract: {context}" for context in batch]

        # Predict
        return self.predict(inputs=inputs, max_length=max_length, **kwargs)

    def multitask(self, batch: List[str], max_length: int = 512, **kwargs):
        """
        Answer extraction + question generation + question answering.

        :param batch: list of context
        :param max_length: max length of outputs
        :param kwargs: extra keyword arguments forwarded to ``self.model.generate``
        :return: dict with keys 'context', 'answers', 'questions' and 'answers_bis'; each value
            is a list holding one entry per context
        """
        # Materialize once so that a one-shot iterable is not exhausted before the loop below
        contexts = list(batch)

        # Build output dict
        dict_batch = {'context': contexts, 'answers': [], 'questions': [], 'answers_bis': []}

        # Iterate over context
        for context in contexts:
            # Answer extraction: the single prediction is split on the '<sep>' marker
            extracted = self.ae(batch=[context], max_length=max_length, **kwargs)[0]
            stripped = (ans.strip() for ans in extracted.split('<sep>'))
            # Drop every blank segment: an empty answer would make str.replace below
            # insert highlight markers between all characters of the context
            answers = [ans for ans in stripped if ans]
            dict_batch['answers'].append(answers)

            # Question generation: one highlighted copy of the context per answer
            for_qg = [context.replace(ans, f'<hl> {ans} <hl> ') for ans in answers]
            questions = self.qg(batch=for_qg, max_length=max_length, **kwargs)
            dict_batch['questions'].append(questions)

            # Question answering: answer each generated question against the original context
            new_answers = self.qa([{'context': context, 'question': question} for question in questions],
                                  max_length=max_length, **kwargs)
            dict_batch['answers_bis'].append(new_answers)

        return dict_batch

    def predict(self, inputs, max_length, **kwargs):
        """
        Inference processing.

        :param inputs: list of inputs
        :param max_length: max_length of outputs (also used as the tokenization length of the inputs)
        :param kwargs: extra keyword arguments forwarded to ``self.model.generate``
        :return: list of decoded predictions; an empty list when ``inputs`` is an empty list/tuple
        """
        # Nothing to tokenize or generate for an empty batch
        if isinstance(inputs, (list, tuple)) and not inputs:
            return []

        # Tokenize inputs
        encoded = self.tokenizer(inputs, max_length=max_length, padding='max_length', truncation=True,
                                 return_tensors="pt")

        # Retrieve input_ids and attention_mask, moved to the device holding the model
        device = self.model.device
        input_ids = encoded.input_ids.to(device)
        attention_mask = encoded.attention_mask.to(device)

        # Predict
        outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=max_length,
                                      **kwargs)

        # Decode outputs
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)