# File size: 2,193 Bytes
# 91ee34e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import re

import torch
from transformers import Pipeline

class MyPipeline(Pipeline):
    """Question-generation pipeline around a seq2seq model's ``generate``.

    Input text is filtered to an allowed character set, prefixed with the
    ``"질문 생성: <unused0>"`` task prompt ("질문 생성" = "question
    generation"), wrapped in BOS/EOS ids, and the first generated sequence is
    decoded into a single string. ``make_question`` applies the pipeline to
    overlapping word windows of a longer text.

    The generation options ``max_length`` and ``num_beams`` may be supplied
    at construction time or per call. Options that are omitted are simply not
    passed to ``generate``.
    """

    # Generation options accepted from callers and forwarded to ``generate``.
    _GENERATE_PARAMS = ("max_length", "num_beams")
    # Every character outside this set (ASCII letters, Hangul syllables,
    # digits, ``,<>:&#`` and space) is removed before tokenization.
    _DISALLOWED_CHARS = re.compile(r'[^A-Za-z가-힣,<>0-9:&# ]')

    def _sanitize_parameters(self, **kwargs):
        """Route the supported generation options to ``preprocess``.

        Returns the ``(preprocess, forward, postprocess)`` kwargs triple that
        the ``Pipeline`` base class expects. Only ``preprocess`` receives
        options.
        """
        preprocess_kwargs = {
            name: kwargs[name] for name in self._GENERATE_PARAMS if name in kwargs
        }
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, **kwargs):
        """Clean *inputs*, add the task prompt and encode it as a 1-row batch.

        Args:
            inputs: Raw input text.
            **kwargs: ``max_length`` / ``num_beams`` are copied into the
                result for ``_forward``. Other keys (e.g. ``frame_size``)
                are ignored.

        Returns:
            dict with ``"inputs"`` (a tensor of shape ``(1, seq_len)`` holding
            token ids), plus whichever generation options were supplied.
        """
        text = self._DISALLOWED_CHARS.sub('', inputs)
        text = "질문 생성: <unused0>" + text

        tokenizer = self.tokenizer
        # BOS/EOS are added explicitly around the encoded prompt.
        input_ids = (
            [tokenizer.bos_token_id]
            + tokenizer.encode(text)
            + [tokenizer.eos_token_id]
        )
        model_inputs = {"inputs": torch.tensor([input_ids])}
        # Only forward options the caller actually set, so a call without
        # them no longer raises KeyError.
        for name in self._GENERATE_PARAMS:
            if name in kwargs:
                model_inputs[name] = kwargs[name]
        return model_inputs

    def _forward(self, model_inputs):
        """Run ``self.model.generate`` on the encoded prompt.

        Args:
            model_inputs: dict produced by ``preprocess``.

        Returns:
            dict with the generated token ids under ``"logits"``. The key name
            is kept for compatibility; the values are ids, not logits.
        """
        tokenizer = self.tokenizer
        generate_kwargs = {
            name: model_inputs[name]
            for name in self._GENERATE_PARAMS
            if name in model_inputs
        }
        res_ids = self.model.generate(
            model_inputs["inputs"],
            eos_token_id=tokenizer.eos_token_id,
            # Forbid the unknown token in the generated output.
            bad_words_ids=[[tokenizer.unk_token_id]],
            **generate_kwargs,
        )
        return {"logits": res_ids}

    def postprocess(self, model_outputs):
        """Decode the first generated sequence and strip ``<s>``/``</s>``."""
        decoded = self.tokenizer.batch_decode(model_outputs["logits"].tolist())[0]
        return decoded.replace('<s>', '').replace('</s>', '')

    def _inference(self, paragraph, **kwargs):
        """Generate one question for *paragraph*.

        Goes through ``Pipeline.forward`` rather than calling ``_forward``
        directly, so the base class can apply its device placement and
        inference context.
        """
        model_inputs = self.preprocess(paragraph, **kwargs)
        model_outputs = self.forward(model_inputs)
        return self.postprocess(model_outputs)

    def make_question(self, text, **kwargs):
        """Generate one question per overlapping word window of *text*.

        Args:
            text: Input text. It is split on any whitespace into words.
            **kwargs: Must include ``frame_size`` (words per window) and
                ``hop_length`` (words between window starts), both positive.
                All kwargs, including ``max_length`` / ``num_beams``, are
                passed on to ``preprocess``.

        Returns:
            list of generated questions, one per window, in text order.
            Windows start at multiples of ``hop_length`` and stop once one
            reaches the end of the text, so the last (possibly shorter)
            window covers the tail. Text shorter than one frame yields a
            single window, and text with no words yields ``[]``.

        Raises:
            KeyError: if ``frame_size`` or ``hop_length`` is missing.
            ValueError: if ``frame_size`` or ``hop_length`` is not positive.
        """
        frame_size = kwargs['frame_size']
        hop_length = kwargs['hop_length']
        if frame_size <= 0 or hop_length <= 0:
            raise ValueError("frame_size and hop_length must be positive")

        words = text.split()
        n_words = len(words)
        outs = []
        for start in range(0, n_words, hop_length):
            script = " ".join(words[start:start + frame_size])
            outs.append(self._inference(script, **kwargs))
            # This window already reaches the end; further windows would
            # only repeat a suffix of it.
            if start + frame_size >= n_words:
                break
        return outs