import itertools

import nltk
from nltk import sent_tokenize
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# import torch  # only needed for the optional GPU device check in __init__

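# Overview (added comment): this pipeline runs the same seq2seq checkpoint in
# two passes.
#   1. Answer extraction: each sentence is wrapped in <hl> tokens inside an
#      "extract answers: ..." prompt, and candidate answer spans are decoded,
#      separated by <sep>.
#   2. Question generation: each extracted answer is wrapped in <hl> tokens
#      inside a "generate question: ..." prompt, and a question is decoded
#      with beam search.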
class PreTrainedPipeline:

    def __init__(self, path=""):
        # Preload everything needed at inference (model, tokenizer, NLTK data).
        # This is only called once, so all the heavy processing and I/O
        # happens here.
        nltk.download('punkt')  # newer NLTK releases may also need 'punkt_tab'
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        self.model_type = "t5"
        # self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = "cpu"

        self.model.to(self.device)


    def __call__(self, inputs: str, max_words_per_answer: int = 3):
        if len(inputs) == 0:
            return []
        # Collapse runs of whitespace before sentence splitting.
        inputs = " ".join(inputs.split())
        sents, answers = self._extract_answers(inputs)
        flat_answers = list(itertools.chain(*answers))

        if len(flat_answers) == 0:
            return []

        questions, qg_examples = self.prepare_and_generate_questions(sents, answers)
        output = [{'answer': example['answer'], 'question': que}
                  for example, que in zip(qg_examples, questions)]
        output = self.clean_generated_QAs(output, max_words_per_answer)
        return output
    
    def prepare_and_generate_questions(self, sents, answers):
        # Turn (sentence, answers) pairs into highlighted QG prompts, then
        # decode one question per prompt.
        qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)

        qg_inputs = [example['source_text'] for example in qg_examples]
        questions = self._generate_questions(qg_inputs)
        return questions, qg_examples


    def clean_answers_list_of_lists(self, answers):
        clean_answers = []
        for answer_list in answers:
            # The decoded output ends with a trailing separator, so the last
            # fragment after splitting on <sep> is not a real answer.
            answer_list = answer_list[:-1]
            # Deduplicate and strip surrounding whitespace.
            answer_list = list(set([a.strip() for a in answer_list]))
            clean_answers.append(answer_list)
        return clean_answers


    def _extract_answers(self, context):
        sents, inputs = self._prepare_inputs_for_ans_extraction(context)
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
        )

        # Keep special tokens so the output can be split on <sep>; stray
        # <pad> tokens are stripped later by remove_pad().
        dec = [self.tokenizer.decode(ids, skip_special_tokens=False) for ids in outs]
        answers = [item.split('<sep>') for item in dec]

        answers = self.clean_answers_list_of_lists(answers)

        return sents, answers


    
    def _prepare_inputs_for_ans_extraction(self, text):
        sents = sent_tokenize(text)

        # Build one prompt per sentence, wrapping that sentence in <hl> so the
        # model extracts answers from it in the context of the full passage, e.g.
        # "extract answers: <hl> First sentence. <hl> Second sentence. </s>"
        inputs = []
        for i in range(len(sents)):
            source_text = "extract answers:"
            for j, sent in enumerate(sents):
                if i == j:
                    sent = "<hl> %s <hl>" % sent
                source_text = "%s %s" % (source_text, sent)
                source_text = source_text.strip()

            if self.model_type == "t5":
                source_text = source_text + " </s>"
            inputs.append(source_text)

        return sents, inputs
          
    def _tokenize(self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512
    ):
        inputs = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            # padding="max_length" is the current equivalent of the deprecated
            # pad_to_max_length flag, so the old flag is dropped here.
            padding="max_length" if padding else False,
            return_tensors="pt"
        )
        return inputs

    def _generate_questions(self, inputs):
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
            num_beams=4,  # beam search tends to give more fluent questions than greedy decoding
        )

        questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
        return questions
    
    def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
        inputs = []
        for i, answer in enumerate(answers):
            if len(answer) == 0:
                continue
            for answer_text in answer:
                sent = sents[i]
                sents_copy = sents[:]
                answer_text = self.remove_pad(answer_text)
                answer_text = answer_text.strip()

                try:
                    ans_start_idx = sent.lower().index(answer_text.lower())
                except ValueError:
                    # The answer does not occur in the sentence, so skip it.
                    continue

                # Wrap the answer span in <hl> tokens so the model knows which
                # span the generated question should target.
                sent = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text):]}"
                sents_copy[i] = sent

                source_text = " ".join(sents_copy)
                source_text = f"generate question: {source_text}"
                if self.model_type == "t5":
                    source_text = source_text + " </s>"

                inputs.append({"answer": answer_text, "source_text": source_text})

        return inputs

    def clean_generated_QAs(self, generated_QAs, max_words_per_answer):
        clean_QAs = []
        answers_used = set()
        # Keep only the first question generated for each distinct answer, and
        # drop answers longer than max_words_per_answer words.
        for qa in generated_QAs:
            answer_word_length = len(qa['answer'].strip().split())
            if qa['answer'] in answers_used or answer_word_length > max_words_per_answer:
                continue
            answers_used.add(qa['answer'])
            clean_QAs.append(qa)
        return clean_QAs

    def remove_pad(self, text):
        # Strip any <pad> tokens left over from decoding with special tokens kept.
        return text.replace("<pad>", "")
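

# Minimal usage sketch (illustrative addition): the checkpoint path below is a
# placeholder; any local seq2seq QA/QG checkpoint whose vocabulary includes the
# <hl> and <sep> tokens used above should work.
if __name__ == "__main__":
    pipe = PreTrainedPipeline(path="./t5-qa-qg-checkpoint")  # placeholder path
    qas = pipe("The Eiffel Tower was completed in 1889. It stands in Paris.")
    for qa in qas:
        print(f"Q: {qa['question']}  A: {qa['answer']}")
    # Each item is a dict: {'answer': <extracted span>, 'question': <generated question>}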