Backedman's picture
Upload QApipeline
c159aef verified
raw
history blame
No virus
1.53 kB
# qapipeline.py
from transformers import PreTrainedModel, Pipeline
from typing import Any, Dict
from transformers import Pipeline
from transformers import PreTrainedTokenizer
from transformers.utils import ModelOutput
from transformers import PreTrainedModel, Pipeline
from typing import Any, Dict, List
class QApipeline(Pipeline):
    """Pipeline wrapper around a model that exposes a ``predict`` method.

    Calling the pipeline directly (``__call__``) bypasses the usual
    ``preprocess`` -> ``_forward`` -> ``postprocess`` flow of the base class:
    it hands the inputs straight to ``self.model.predict`` and formats the
    result with ``_process_output``. The stage methods are still implemented
    so that the base-class machinery can drive the same model if it is
    invoked through that flow.

    The model's ``predict`` result is indexed positionally: element ``0`` is
    used as the confidence and element ``1`` as the guess.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        **kwargs
    ):
        """Initialise the pipeline.

        Args:
            model: Model to run; it must provide ``predict``.
            tokenizer: Tokenizer associated with the model. It is forwarded
                to the base class so it is stored on the pipeline instead of
                being silently discarded.
            **kwargs: Extra keyword arguments passed to ``Pipeline.__init__``.
        """
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            **kwargs
        )
        print("in __init__")

    def __call__(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Run the model on ``inputs`` and return the formatted answer.

        Args:
            inputs: Payload handed unchanged to ``self.model.predict``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            A dict with the keys ``'guess'`` and ``'confidence'``.
        """
        outputs = self.model.predict(inputs)
        answer = self._process_output(outputs)
        print("in __call___")
        return answer

    def _process_output(self, outputs: Any) -> Dict[str, Any]:
        """Convert a raw ``predict`` result into the answer dict.

        Unlike ``postprocess``, the confidence is returned as-is (it is not
        coerced with ``float``).
        """
        print("in process outputs")
        answer = {'guess': outputs[1], 'confidence': outputs[0]}
        return answer

    def _sanitize_parameters(self, **kwargs):
        """Return empty preprocess, forward and postprocess parameter dicts.

        No call-time keyword arguments are routed to any stage.
        """
        print("in sanitize params")
        return {}, {}, {}

    def preprocess(self, inputs):
        """Return ``inputs`` unchanged; no preprocessing is applied."""
        print("in preprocess")
        return inputs

    def postprocess(self, outputs):
        """Format raw model outputs, coercing the confidence to ``float``.

        Returns:
            A dict with the keys ``'guess'`` and ``'confidence'``.
        """
        print("in postprocess")
        answer = {'guess': outputs[1], 'confidence': float(outputs[0])}
        return answer

    def _forward(self, input_tensors, **forward_parameters: Dict) -> Any:
        """Run the model's ``predict`` on the preprocessed inputs.

        The base class's ``_forward`` is abstract and raises
        ``NotImplementedError``, so delegating to it could never succeed.
        The model is called the same way ``__call__`` calls it.
        """
        print("in _forward")
        return self.model.predict(input_tensors, **forward_parameters)