|
from transformers import Pipeline |
|
from transformers.utils import ModelOutput |
|
|
|
from transformers import PreTrainedModel, Pipeline |
|
from typing import Any, Dict, List |
|
|
|
class QApipeline(Pipeline):
    """Question-answering pipeline built on ``transformers.Pipeline``.

    A call builds ``{"question": ..., "context": ...}``, sends it through
    ``preprocess`` -> ``_forward`` -> ``postprocess`` -> ``_process_output``,
    and wraps the result as ``{"answer": ...}``. Every stage except
    ``_forward`` currently passes its input through unchanged.

    NOTE(review): nothing here tokenizes, so the model receives the raw
    ``question``/``context`` strings as keyword arguments. This assumes a
    model whose ``forward`` accepts those strings directly -- confirm, or
    add tokenization in ``preprocess`` (and span decoding in
    ``_process_output``).
    """

    def __init__(self, model: PreTrainedModel, **kwargs: Any) -> None:
        """Wrap *model*; every other keyword argument goes to ``Pipeline``."""
        super().__init__(model=model, **kwargs)

    def __call__(self, question: str, context: str, **kwargs: Any) -> Dict[str, Any]:
        """Answer *question* from *context*.

        Args:
            question: The question text.
            context: The passage to answer from.
            **kwargs: Accepted for signature compatibility; currently ignored.

        Returns:
            ``{"answer": value}`` where ``value`` is whatever
            ``_process_output`` produces from the model outputs.
        """
        inputs = {"question": question, "context": context}

        # Go through the pipeline hooks instead of calling self.model
        # directly, so overriding preprocess/_forward/postprocess actually
        # changes what __call__ does.
        model_inputs = self.preprocess(inputs)
        outputs = self._forward(model_inputs)
        outputs = self.postprocess(outputs)

        return {"answer": self._process_output(outputs)}

    def _process_output(self, outputs: Any) -> Any:
        """Turn model outputs into the value stored under ``"answer"``.

        Currently returns *outputs* unchanged; no answer-span decoding is
        performed, so the result is not necessarily a string.
        """
        return outputs

    def _sanitize_parameters(self, **kwargs: Any):
        """Split keyword arguments into (preprocess, forward, postprocess) params.

        No pipeline-specific parameters are supported yet, so all three
        dicts are empty and any keyword arguments are ignored. They are
        deliberately not passed to ``print()``, which only accepts
        ``sep``/``end``/``file``/``flush`` and raises TypeError otherwise.
        """
        return {}, {}, {}

    def preprocess(self, inputs: Any) -> Any:
        """Prepare model inputs; currently returns *inputs* unchanged."""
        return inputs

    def postprocess(self, outputs: Any) -> Any:
        """Post-process model outputs; currently returns *outputs* unchanged."""
        return outputs

    def _forward(self, input_tensors: Any, **forward_parameters: Any) -> ModelOutput:
        """Run the wrapped model on the prepared inputs.

        ``Pipeline._forward`` in transformers is abstract (it raises
        NotImplementedError), so this override must call the model itself
        rather than delegating to ``super()``.

        Args:
            input_tensors: Mapping of model keyword arguments, as produced by
                ``preprocess``.
            **forward_parameters: Extra keyword arguments for the model call.

        Returns:
            Whatever the model returns.
        """
        return self.model(**input_tensors, **forward_parameters)
|
|