# Source listing: commit db70e6c, file size 1,205 bytes.
from transformers import Pipeline
from transformers.utils import ModelOutput
from transformers import PreTrainedModel, Pipeline
from typing import Any, Dict, List
class QApipeline(Pipeline):
    """Question-answering pipeline built on the ``transformers`` Pipeline API.

    ``__call__`` takes a question/context pair and hands it to the base
    ``Pipeline.__call__``, which drives the ``preprocess`` -> ``_forward`` ->
    ``postprocess`` stages defined below.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        **kwargs
    ):
        """Create the pipeline.

        Args:
            model: The model to run.
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``
                (for example ``tokenizer`` or ``device``).
        """
        super().__init__(
            model=model,
            **kwargs
        )

    def __call__(
        self,
        question: str,
        context: str,
        **kwargs
    ) -> Dict[str, Any]:
        """Answer ``question`` using ``context``.

        Args:
            question: The question text.
            context: The passage the answer should come from.
            **kwargs: Call-time options, forwarded to ``Pipeline.__call__``,
                which splits them via ``_sanitize_parameters``.

        Returns:
            ``{"answer": value}``, where ``value`` is whatever
            ``_process_output`` produces from the postprocessed model output.
        """
        # Route through the base Pipeline instead of calling self.model
        # directly: the old direct call skipped preprocess (tokenization),
        # _forward, postprocess, and dropped **kwargs entirely.
        outputs = super().__call__(
            {"question": question, "context": context},
            **kwargs
        )
        return {"answer": self._process_output(outputs)}

    def _process_output(
        self,
        outputs: Any
    ) -> Any:
        """Convert the pipeline output into the value stored under ``"answer"``.

        Currently a pass-through: ``outputs`` is returned unchanged.
        """
        # TODO(review): decode an answer span here if the model is an
        # extractive QA head (start/end logits) — not assumed from this file.
        return outputs

    def _sanitize_parameters(self, **kwargs):
        """Split constructor/call kwargs into per-stage parameter dicts.

        Returns:
            ``(preprocess_params, forward_params, postprocess_params)``.
            No stage currently takes parameters, so all three are empty and
            any kwargs are ignored.
        """
        # The previous `print(**kwargs)` raised TypeError for any key print()
        # does not accept, and otherwise printed a stray blank line.
        return {}, {}, {}

    def preprocess(self, inputs, **preprocess_parameters):
        """Prepare the question/context pair for the model.

        Args:
            inputs: Mapping with ``"question"`` and ``"context"`` strings,
                as built by ``__call__``.
            **preprocess_parameters: Unused; accepted for the Pipeline contract.

        Returns:
            The tokenizer encoding when a tokenizer is configured. Otherwise
            ``inputs`` unchanged, so a model whose forward takes raw
            ``question``/``context`` keyword arguments keeps working.
        """
        if self.tokenizer is None:
            return inputs
        # Encode as a sequence pair; truncate only the context so the
        # question is never cut when the pair exceeds the model's max length.
        return self.tokenizer(
            inputs["question"],
            inputs["context"],
            return_tensors=self.framework,
            truncation="only_second",
        )

    def postprocess(self, outputs, **postprocess_parameters):
        """Return the model outputs unchanged (no postprocessing yet)."""
        return outputs

    def _forward(self, input_tensors, **forward_parameters: Dict) -> ModelOutput:
        """Run the model on the preprocessed inputs.

        Args:
            input_tensors: Mapping produced by ``preprocess``.
            **forward_parameters: Extra keyword arguments for the model call.

        Returns:
            The model's output.
        """
        # Pipeline._forward is abstract (it raises NotImplementedError), so the
        # previous super() delegation could never succeed.
        return self.model(**input_tensors, **forward_parameters)