File size: 1,525 Bytes
c159aef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# qapipeline.py

import logging
from typing import Any, Dict, List

from transformers import Pipeline, PreTrainedModel, PreTrainedTokenizer
from transformers.utils import ModelOutput

class QApipeline(Pipeline):
    """Question-answering pipeline that wraps a model exposing ``predict``.

    The result of ``self.model.predict(...)`` is read by index: element 0 is
    reported as ``'confidence'`` and element 1 as ``'guess'``.
    NOTE(review): assumes the model's ``predict`` returns an indexable
    ``(confidence, guess, ...)`` sequence - confirm against the model class.
    """

    # Debug tracing goes through logging rather than print() so callers are
    # not spammed on stdout; enable DEBUG for this module to see the messages.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        **kwargs
    ):
        """Create the pipeline.

        Args:
            model: Model whose ``predict`` method produces the answer.
            tokenizer: Tokenizer for the model. It is forwarded to
                ``Pipeline.__init__`` so that ``self.tokenizer`` is set
                (it was previously accepted and then silently dropped).
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            **kwargs
        )
        self._logger.debug("in __init__")

    def __call__(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Run the model on *inputs* and return the formatted answer.

        This overrides ``Pipeline.__call__`` and does not go through
        ``preprocess``/``_forward``/``postprocess`` or any batching.

        Args:
            inputs: Passed as-is to ``self.model.predict``.
            **kwargs: Accepted for signature compatibility; ignored.

        Returns:
            ``{'guess': outputs[1], 'confidence': outputs[0]}``, where
            ``outputs`` is what ``self.model.predict`` returned.
        """
        outputs = self.model.predict(inputs)
        answer = self._process_output(outputs)
        self._logger.debug("in __call___")
        return answer

    def _process_output(self, outputs: Any) -> Dict[str, Any]:
        """Map the raw ``predict`` result to a ``guess``/``confidence`` dict.

        Unlike ``postprocess``, the confidence value is returned unconverted.
        """
        self._logger.debug("in process outputs")
        return {'guess': outputs[1], 'confidence': outputs[0]}

    def _sanitize_parameters(self, **kwargs):
        """Return empty preprocess/forward/postprocess parameter dicts.

        Any keyword arguments are ignored.
        """
        self._logger.debug("in sanitize params")
        return {}, {}, {}

    def preprocess(self, inputs):
        """Return *inputs* unchanged; the model consumes them directly."""
        self._logger.debug("in preprocess")
        return inputs

    def postprocess(self, outputs):
        """Format ``_forward`` output as ``{'guess', 'confidence'}``.

        ``outputs[0]`` is converted with ``float()``, so it must be
        float-convertible (e.g. a Python number or a one-element tensor).
        """
        self._logger.debug("in postprocess")
        return {'guess': outputs[1], 'confidence': float(outputs[0])}

    def _forward(self, input_tensors, **forward_parameters: Dict) -> Any:
        """Run the model on the (unchanged) preprocessed inputs.

        ``Pipeline._forward`` is abstract and only raises
        ``NotImplementedError``, so delegating to ``super()`` could never
        succeed. The model is invoked the same way ``__call__`` does it.
        ``forward_parameters`` is ignored; ``_sanitize_parameters`` never
        produces any forward parameters.
        """
        self._logger.debug("in _forward")
        return self.model.predict(input_tensors)