Backedman's picture
Update QAPipeline.py
f3107e9 verified
import logging
from typing import Any, Dict, List

from transformers import Pipeline
from transformers import PreTrainedModel, Pipeline
from transformers.utils import ModelOutput
class QApipeline(Pipeline):
    """Question-answering pipeline around a model that exposes ``predict``.

    The wrapped model's ``predict(question)`` result is indexed as a pair:
    element ``0`` is the confidence and element ``1`` is the guessed answer.
    The pipeline converts that pair into
    ``{'guess': <answer>, 'confidence': <float>}``.
    """

    # Debug tracing goes through logging instead of print() so it can be
    # silenced/enabled by the host application.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        model: PreTrainedModel,
        **kwargs,
    ):
        """Build the pipeline around *model*; remaining kwargs go to ``Pipeline``."""
        super().__init__(model=model, **kwargs)
        self._logger.debug("QApipeline initialised")

    def __call__(self, question: str, **kwargs) -> Dict[str, Any]:
        """Answer *question* with the wrapped model.

        Runs the pipeline stages (``preprocess`` -> ``_forward`` ->
        ``postprocess``) directly rather than going through the base
        ``Pipeline.__call__`` machinery.

        Args:
            question: Question text; handed to ``model.predict`` unchanged.
            **kwargs: Accepted for API compatibility; currently ignored.

        Returns:
            ``{'guess': <answer>, 'confidence': <float>}``.
        """
        model_inputs = self.preprocess(question)
        model_outputs = self._forward(model_inputs)
        return self.postprocess(model_outputs)

    def _process_output(self, outputs: Any) -> Dict[str, Any]:
        """Format raw model outputs; kept for existing callers of this name.

        Delegates to :meth:`postprocess` so there is a single formatting
        rule (confidence is returned as a ``float``).
        """
        return self.postprocess(outputs)

    def _sanitize_parameters(self, **kwargs):
        """Split call kwargs into (preprocess, forward, postprocess) params.

        This pipeline exposes no tunable parameters, so all three dicts are
        empty and any kwargs are ignored.
        """
        return {}, {}, {}

    def preprocess(self, inputs, **preprocess_parameters):
        """Return *inputs* unchanged; the question is passed to ``predict`` as-is."""
        return inputs

    def _forward(self, input_tensors, **forward_parameters: Dict) -> Any:
        """Run the wrapped model's ``predict`` on the preprocessed input.

        The base ``Pipeline._forward`` is abstract (it raises
        ``NotImplementedError``), so the model call must happen here.
        ``forward_parameters`` is always empty via ``_sanitize_parameters``;
        it is forwarded only so explicit callers can pass extra options.
        """
        outputs = self.model.predict(input_tensors, **forward_parameters)
        self._logger.debug("model.predict returned %r", outputs)
        return outputs

    def postprocess(self, outputs, **postprocess_parameters) -> Dict[str, Any]:
        """Convert a confidence/guess pair into the pipeline's result dict.

        Args:
            outputs: Indexable result of ``model.predict``; ``outputs[0]`` is
                the confidence and ``outputs[1]`` the guessed answer.

        Returns:
            ``{'guess': outputs[1], 'confidence': float(outputs[0])}``.
        """
        # float() presumably turns a numpy/torch scalar into a plain,
        # JSON-serialisable Python float — confirm against the model.
        return {'guess': outputs[1], 'confidence': float(outputs[0])}