# coding:utf-8
"""
Filename: mt5.py
Author: @DvdNss
Created on 12/30/2021
"""

from typing import List

from pytorch_lightning import LightningModule
from transformers import MT5ForConditionalGeneration, AutoTokenizer
class MT5(LightningModule):
    """
    Google MT5 transformer class.
    """

    def __init__(self, model_name_or_path: str = None):
        """
        Initialize module.

        :param model_name_or_path: model name
        """
        super().__init__()

        # Load model and tokenizer
        self.save_hyperparameters()
        self.model = MT5ForConditionalGeneration.from_pretrained(
            model_name_or_path) if model_name_or_path is not None else None
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                                       use_fast=True) if model_name_or_path is not None else None
    def forward(self, **inputs):
        """
        Forward inputs.

        :param inputs: dictionary of inputs (input_ids, attention_mask, labels)
        """
        return self.model(**inputs)
    def qa(self, batch: List[dict], max_length: int = 512, **kwargs):
        """
        Question answering prediction.

        :param batch: batch of dicts {question: q, context: c}
        :param max_length: max length of output
        """
        # Transform inputs
        inputs = [f"question: {context['question']} context: {context['context']}" for context in batch]

        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)

        return outputs
    def qg(self, batch: List[str] = None, max_length: int = 512, **kwargs):
        """
        Question generation prediction.

        :param batch: batch of contexts with highlighted elements
        :param max_length: max length of output
        """
        # Transform inputs
        inputs = [f"generate: {context}" for context in batch]

        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)

        return outputs
    def ae(self, batch: List[str], max_length: int = 512, **kwargs):
        """
        Answer extraction prediction.

        :param batch: list of contexts
        :param max_length: max length of output
        """
        # Transform inputs
        inputs = [f"extract: {context}" for context in batch]

        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)

        return outputs
    def multitask(self, batch: List[str], max_length: int = 512, **kwargs):
        """
        Answer extraction + question generation + question answering.

        :param batch: list of contexts
        :param max_length: max length of outputs
        """
        # Build output dict
        dict_batch = {'context': [context for context in batch], 'answers': [], 'questions': [], 'answers_bis': []}

        # Iterate over contexts
        for context in batch:
            # Extract answers from the context and drop empty fragments left by the <sep> split
            answers = self.ae(batch=[context], max_length=max_length, **kwargs)[0]
            answers = answers.split('<sep>')
            answers = [ans.strip() for ans in answers if ans.strip() != '']
            dict_batch['answers'].append(answers)

            # Highlight each answer in the context and generate the matching questions
            for_qg = [context.replace(ans, f'<hl> {ans} <hl> ') for ans in answers]
            questions = self.qg(batch=for_qg, max_length=max_length, **kwargs)
            dict_batch['questions'].append(questions)

            # Answer the generated questions to cross-check the extracted answers
            new_answers = self.qa([{'context': context, 'question': question} for question in questions],
                                  max_length=max_length, **kwargs)
            dict_batch['answers_bis'].append(new_answers)

        return dict_batch
    def predict(self, inputs, max_length, **kwargs):
        """
        Inference processing.

        :param inputs: list of inputs
        :param max_length: max length of outputs
        """
        # Tokenize inputs
        inputs = self.tokenizer(inputs, max_length=max_length, padding='max_length', truncation=True,
                                return_tensors="pt")

        # Retrieve input_ids and attention_mask, and move them to the model device
        input_ids = inputs.input_ids.to(self.model.device)
        attention_mask = inputs.attention_mask.to(self.model.device)

        # Predict
        outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=max_length,
                                      **kwargs)

        # Decode outputs
        predictions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        return predictions
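

# Usage sketch (illustrative, not part of the original module): shows how the ae / qg / qa
# helpers chain together via multitask(). The checkpoint path below is a placeholder and
# assumes an MT5 checkpoint fine-tuned on the "extract:" / "generate:" /
# "question: ... context: ..." prompt formats used above.
if __name__ == '__main__':
    # Load a fine-tuned checkpoint (placeholder path, replace with a real one)
    model = MT5(model_name_or_path='path/to/fine-tuned-mt5')
    model.eval()

    context = 'The Eiffel Tower was completed in 1889 and is located in Paris.'

    # Run the full answer-extraction -> question-generation -> question-answering pipeline
    result = model.multitask(batch=[context], max_length=512)
    print(result['answers'])      # answers extracted from the context
    print(result['questions'])    # questions generated for each extracted answer
    print(result['answers_bis'])  # answers predicted for the generated questions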