sunwaee committed
Commit c1af279
1 parent: 3dac436

added model script

Files changed (1): mt5.py +141 -0
mt5.py ADDED
@@ -0,0 +1,141 @@
+# coding:utf-8
+"""
+Filename: mt5.py
+Author: @DvdNss
+
+Created on 12/30/2021
+"""
+
+from typing import List, Optional
+
+from pytorch_lightning import LightningModule
+from transformers import MT5ForConditionalGeneration, AutoTokenizer
+
+
+class MT5(LightningModule):
+    """
+    Google MT5 transformer class.
+    """
+
+    def __init__(self, model_name_or_path: Optional[str] = None):
+        """
+        Initialize module.
+
+        :param model_name_or_path: model name or local path
+        """
+
+        super().__init__()
+
+        # Load model and tokenizer
+        self.save_hyperparameters()
+        self.model = MT5ForConditionalGeneration.from_pretrained(
+            model_name_or_path) if model_name_or_path is not None else None
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
+                                                       use_fast=True) if model_name_or_path is not None else None
+
+    def forward(self, **inputs):
+        """
+        Forward inputs.
+
+        :param inputs: dictionary of inputs (input_ids, attention_mask, labels)
+        """
+
+        return self.model(**inputs)
+
+    def qa(self, batch: List[dict], max_length: int = 512, **kwargs):
+        """
+        Question answering prediction.
+
+        :param batch: batch of dicts {question: q, context: c}
+        :param max_length: max length of output
+        """
+
+        # Transform inputs
+        inputs = [f"question: {context['question']} context: {context['context']}" for context in batch]
+
+        # Predict
+        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)
+
+        return outputs
+
+    def qg(self, batch: List[str], max_length: int = 512, **kwargs):
+        """
+        Question generation prediction.
+
+        :param batch: batch of contexts with highlighted elements
+        :param max_length: max length of output
+        """
+
+        # Transform inputs
+        inputs = [f"generate: {context}" for context in batch]
+
+        # Predict
+        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)
+
+        return outputs
+
+    def ae(self, batch: List[str], max_length: int = 512, **kwargs):
+        """
+        Answer extraction prediction.
+
+        :param batch: list of contexts
+        :param max_length: max length of output
+        """
+
+        # Transform inputs
+        inputs = [f"extract: {context}" for context in batch]
+
+        # Predict
+        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)
+
+        return outputs
+
+    def multitask(self, batch: List[str], max_length: int = 512, **kwargs):
+        """
+        Answer extraction + question generation + question answering.
+
+        :param batch: list of contexts
+        :param max_length: max length of outputs
+        """
+
+        # Build output dict
+        dict_batch = {'context': list(batch), 'answers': [], 'questions': [], 'answers_bis': []}
+
+        # Iterate over contexts
+        for context in batch:
+            # Extract answers, split on the separator token and drop empty fragments
+            answers = self.ae(batch=[context], max_length=max_length, **kwargs)[0]
+            answers = answers.split('<sep>')
+            answers = [ans.strip() for ans in answers if ans.strip()]
+            dict_batch['answers'].append(answers)
+            # Highlight each answer in the context, then generate one question per answer
+            for_qg = [context.replace(ans, f"<hl> {ans} <hl>") for ans in answers]
+            questions = self.qg(batch=for_qg, max_length=max_length, **kwargs)
+            dict_batch['questions'].append(questions)
+            # Answer the generated questions as a consistency round-trip
+            new_answers = self.qa([{'context': context, 'question': question} for question in questions],
+                                  max_length=max_length, **kwargs)
+            dict_batch['answers_bis'].append(new_answers)
+        return dict_batch
+
+    def predict(self, inputs, max_length, **kwargs):
+        """
+        Inference processing.
+
+        :param inputs: list of inputs
+        :param max_length: max length of outputs
+        """
+
+        # Tokenize inputs
+        inputs = self.tokenizer(inputs, max_length=max_length, padding='max_length', truncation=True,
+                                return_tensors="pt")
+
+        # Retrieve input_ids and attention_mask
+        input_ids = inputs.input_ids.to(self.model.device)
+        attention_mask = inputs.attention_mask.to(self.model.device)
+
+        # Predict
+        outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=max_length,
+                                      **kwargs)
+
+        # Decode outputs
+        predictions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+        return predictions
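
A minimal usage sketch (not part of the commit) may help show how the class is meant to be driven. It assumes a checkpoint fine-tuned for this multitask format, i.e. one that understands the question:/generate:/extract: prefixes and the <hl>/<sep> tokens; the checkpoint path below is a placeholder, not a name this repo necessarily ships.

# Usage sketch: placeholder checkpoint, assuming a multitask fine-tuned mT5
from mt5 import MT5

model = MT5("path/to/finetuned-mt5")  # hypothetical checkpoint path
model.eval()  # inference mode

# End-to-end pass: extract answers, generate questions, answer them back
result = model.multitask(["Paris is the capital of France."], max_length=512)
print(result['answers'])      # extracted answer spans per context
print(result['questions'])    # one generated question per extracted answer
print(result['answers_bis'])  # answers recovered from the generated questions

Since multitask returns its round-trip answers under answers_bis, a quick sanity check is to compare answers and answers_bis for each context.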