File size: 5,468 Bytes
fa9baae
5d3c45e
fa9baae
7738cde
5d3c45e
 
 
 
 
 
 
48a26a2
 
fa9baae
5d3c45e
8fd48be
fa9baae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from typing import Dict, List, Any
import itertools
from nltk import sent_tokenize

class PreTrainedPipeline():
    """Generate question/answer pairs from free text with one seq2seq model.

    The same model is used for two passes:
      1. answer extraction: one "extract answers:" prompt per sentence, with
         that sentence wrapped in ``<hl>`` markers;
      2. question generation: one "generate question:" prompt per extracted
         answer, with the answer wrapped in ``<hl>`` markers inside its sentence.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer once; all heavy I/O happens here.

        Args:
            path: Directory or hub id passed to ``from_pretrained``.
        """
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Read by the prompt builders to decide whether to append " </s>".
        # It was never set before, so every call raised AttributeError.
        self.model_type = self.model.config.model_type

    @property
    def device(self):
        """Device of the model's parameters (follows any later ``model.to(...)``)."""
        return self.model.device

    def __call__(self, inputs: str):
        """Return ``[{'answer': ..., 'question': ...}, ...]`` for ``inputs``.

        Returns an empty list when the text is empty or whitespace-only, or when
        no usable answer survives extraction.
        """
        # Collapse every run of whitespace (including newlines) to one space.
        inputs = " ".join(inputs.split())
        if not inputs:
            return []

        sents, answers = self._extract_answers(inputs)
        # `answers` is a list of per-sentence lists; stop if all are empty.
        if not any(answers):
            return []

        qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)
        # All answers may have been dropped (blank or not found in their
        # sentence); generating on an empty batch would fail.
        if not qg_examples:
            return []

        qg_inputs = [example['source_text'] for example in qg_examples]
        questions = self._generate_questions(qg_inputs)
        output = [
            {'answer': example['answer'], 'question': question}
            for example, question in zip(qg_examples, questions)
        ]
        return self.clean_generated_QAs(output)

    def _extract_answers(self, context):
        """Run the answer-extraction pass over every sentence of ``context``.

        Returns:
            ``(sents, answers)`` where ``answers[i]`` is the list of raw answer
            strings decoded for the prompt that highlights ``sents[i]``.
            Special tokens are kept while decoding so the output can be split on
            ``<sep>``. Leftover ``<pad>`` tokens are removed later by
            ``remove_pad``.
        """
        sents, inputs = self._prepare_inputs_for_ans_extraction(context)
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
        )

        dec = [self.tokenizer.decode(ids, skip_special_tokens=False) for ids in outs]
        # Drop the piece after the last <sep>. It is presumably the
        # end-of-sequence tail rather than an answer.
        answers = [item.split('<sep>')[:-1] for item in dec]
        return sents, answers

    def _prepare_inputs_for_ans_extraction(self, text):
        """Build one "extract answers:" prompt per sentence of ``text``.

        Each prompt contains every sentence, with the i-th one wrapped in
        ``<hl>`` markers. Returns ``(sents, prompts)``.
        """
        sents = sent_tokenize(text)

        inputs = []
        for i in range(len(sents)):
            source_text = "extract answers:"
            for j, sent in enumerate(sents):
                if i == j:
                    sent = "<hl> %s <hl>" % sent
                source_text = "%s %s" % (source_text, sent)
                source_text = source_text.strip()

            if self.model_type == "t5":
                source_text = source_text + " </s>"
            inputs.append(source_text)

        return sents, inputs

    def _tokenize(self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512
    ):
        """Batch-encode ``inputs`` into PyTorch tensors.

        When ``padding`` is truthy, every row is padded to ``max_length``.
        The deprecated ``pad_to_max_length`` flag was redundant with
        ``padding=`` and has been dropped.
        """
        return self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt"
        )

    def _generate_questions(self, inputs):
        """Generate one question per "generate question:" prompt (beam search, 4 beams)."""
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
            num_beams=4,
        )

        return [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]

    def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
        """Build one question-generation prompt per usable extracted answer.

        Args:
            sents: Sentences of the input text.
            answers: ``answers[i]`` lists raw answer strings for ``sents[i]``.

        Returns:
            A list of ``{"answer": str, "source_text": str}`` dicts. Answers
            that are blank after cleanup, or that do not occur in their
            sentence (case-insensitive), are skipped.
        """
        inputs = []
        for i, answer in enumerate(answers):
            for answer_text in answer:
                answer_text = self.remove_pad(answer_text).strip()
                # "".index("") == 0 would otherwise highlight an empty answer.
                if not answer_text:
                    continue

                sent = sents[i]
                try:
                    ans_start_idx = sent.lower().index(answer_text.lower())
                except ValueError:
                    # The answer is not in the sentence, so skip it.
                    continue

                sents_copy = sents[:]
                sents_copy[i] = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text): ]}"

                source_text = " ".join(sents_copy)
                source_text = f"generate question: {source_text}"
                if self.model_type == "t5":
                    source_text = source_text + " </s>"

                inputs.append({"answer": answer_text, "source_text": source_text})

        return inputs

    def clean_generated_QAs(self, generated_QAs):
        """Keep only the first QA pair for each distinct answer, preserving order.

        The original used ``break`` on the first duplicate, which discarded all
        later, distinct answers as well.
        """
        clean_QAs = []
        answers_used = set()
        for qa in generated_QAs:
            if qa['answer'] in answers_used:
                continue
            answers_used.add(qa['answer'])
            clean_QAs.append(qa)
        return clean_QAs

    def remove_pad(self, str):
        """Return ``str`` with every ``<pad>`` token removed."""
        # NOTE: the parameter shadows the builtin `str`. Its name is kept for
        # backward compatibility with keyword callers.
        return str.replace("<pad>", "")