# coding:utf-8
"""
Filename: mt5.py
Author: @DvdNss

Created on 12/30/2021
"""

from typing import List, Optional

from pytorch_lightning import LightningModule
from transformers import MT5ForConditionalGeneration, AutoTokenizer


class MT5(LightningModule):
    """
    Google MT5 transformer class.
    """

    def __init__(self, model_name_or_path: Optional[str] = None):
        """
        Initialize module.

        :param model_name_or_path: HuggingFace model name or local path; when None,
            the model and tokenizer are left unset (e.g. before loading from a checkpoint)
        """

        super().__init__()

        # Load model and tokenizer, or leave them unset when no name/path is given
        self.save_hyperparameters()
        if model_name_or_path is not None:
            self.model = MT5ForConditionalGeneration.from_pretrained(model_name_or_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
        else:
            self.model = None
            self.tokenizer = None

    def forward(self, **inputs):
        """
        Forward pass.

        :param inputs: dictionary of model inputs (input_ids, attention_mask, labels)
        """

        return self.model(**inputs)

    def qa(self, batch: List[dict], max_length: int = 512, **kwargs):
        """
        Question answering prediction.

        :param batch: batch of dicts with 'question' and 'context' keys
        :param max_length: max length of output
        """

        # Build "question: ... context: ..." prompts
        inputs = [f"question: {example['question']}  context: {example['context']}" for example in batch]

        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)

        return outputs

    def qg(self, batch: List[str], max_length: int = 512, **kwargs):
        """
        Question generation prediction.

        :param batch: batch of contexts with <hl>-highlighted answer spans
        :param max_length: max length of output
        """

        # Build "generate: ..." prompts
        inputs = [f"generate: {context}" for context in batch]

        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)

        return outputs

    def ae(self, batch: List[str], max_length: int = 512, **kwargs):
        """
        Answer extraction prediction.

        :param batch: list of contexts
        :param max_length: max length of output
        """

        # Transform inputs
        inputs = [f"extract: {context}" for context in batch]

        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)

        return outputs

    def multitask(self, batch: List[str], max_length: int = 512, **kwargs):
        """
        Full pipeline: answer extraction, then question generation, then question
        answering on the generated questions.

        :param batch: list of contexts
        :param max_length: max length of outputs
        """

        # Build output dict
        dict_batch = {'context': list(batch), 'answers': [], 'questions': [], 'answers_bis': []}

        # Iterate over context
        for context in batch:
            answers = self.ae(batch=[context], max_length=max_length, **kwargs)[0]
            answers = [ans.strip() for ans in answers.split('<sep>') if ans.strip()]
            dict_batch['answers'].append(answers)
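            # Wrap each extracted answer span in <hl> markers, the highlight
            # format the question-generation prompt expects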
            for_qg = [context.replace(ans, f"<hl> {ans} <hl>") for ans in answers]
            questions = self.qg(batch=for_qg, max_length=max_length, **kwargs)
            dict_batch['questions'].append(questions)
            new_answers = self.qa([{'context': context, 'question': question} for question in questions],
                                  max_length=max_length, **kwargs)
            dict_batch['answers_bis'].append(new_answers)

        return dict_batch

    def predict(self, inputs: List[str], max_length: int, **kwargs):
        """
        Tokenize inputs, generate, and decode.

        :param inputs: list of input strings
        :param max_length: max length of tokenized inputs and generated outputs
        """

        # Tokenize inputs
        inputs = self.tokenizer(inputs, max_length=max_length, padding='max_length', truncation=True,
                                return_tensors="pt")

        # Retrieve input_ids and attention_mask
        input_ids = inputs.input_ids.to(self.model.device)
        attention_mask = inputs.attention_mask.to(self.model.device)

        # Generate; extra generation kwargs (e.g. num_beams, do_sample) are forwarded
        outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=max_length,
                                      **kwargs)

        # Decode outputs
        predictions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        return predictions
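

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint name below is a placeholder: the
    # qa/qg/ae prompts assume an MT5 checkpoint fine-tuned on these task
    # prefixes, so substitute your own model name or path.
    model = MT5(model_name_or_path="google/mt5-small")
    model.eval()

    context = "Python was created by Guido van Rossum and first released in 1991."

    # Question answering over a raw context
    print(model.qa([{'question': 'Who created Python?', 'context': context}]))

    # Question generation from a context with a highlighted answer span
    highlighted = context.replace('Guido van Rossum', '<hl> Guido van Rossum <hl>')
    print(model.qg([highlighted]))

    # Full answer-extraction -> question-generation -> QA pipeline
    print(model.multitask([context]))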