|
from typing import Any, Dict |
|
from transformers import Pipeline, AutoModel, AutoTokenizer |
|
from transformers.pipelines.base import GenericTensor, ModelOutput |
|
|
|
|
|
class HiveTokenClassification(Pipeline):
    """Thin ``transformers.Pipeline`` wrapper around a model exposing ``predict``.

    The three pipeline stages are deliberately minimal:

    * ``preprocess`` passes the input through untouched;
    * ``_forward`` delegates everything to
      ``self.model.predict(inputs, self.tokenizer, **forward_parameters)``;
    * ``postprocess`` wraps the prediction in a dict together with its length.
    """

    def _sanitize_parameters(self, **kwargs: Any):
        """Split call-time keyword arguments into per-stage parameter dicts.

        Only ``output_style`` is recognised. When present, it is routed to
        ``_forward`` (and from there on to ``model.predict``). No other keyword
        is routed to any stage.

        Returns:
            A ``(preprocess_params, forward_params, postprocess_params)`` tuple.
            The preprocess and postprocess dicts are always empty.
        """
        forward_parameters = {}
        # NOTE(review): the accepted values of ``output_style`` are defined by
        # the model's ``predict`` method, not here; confirm against the model.
        if "output_style" in kwargs:
            forward_parameters["output_style"] = kwargs["output_style"]
        return {}, forward_parameters, {}

    def preprocess(self, input_: Any, **preprocess_parameters: Any) -> Any:
        """Return ``input_`` unchanged.

        No tokenization happens in this stage. The tokenizer is handed to
        ``model.predict`` in ``_forward`` instead. ``preprocess_parameters`` is
        accepted for interface compatibility and is not used.
        """
        return input_

    def _forward(self, input_tensors: Any, **forward_parameters: Any) -> ModelOutput:
        """Run ``self.model.predict`` on the pipeline input.

        Args:
            input_tensors: Whatever ``preprocess`` returned. Despite the name,
                this is the raw, untokenized pipeline input.
            **forward_parameters: Keyword arguments produced by
                ``_sanitize_parameters`` (at most ``output_style``).

        Returns:
            The value returned by ``model.predict``. Its exact type is not
            visible here; ``postprocess`` requires it to support ``len()``.
        """
        return self.model.predict(input_tensors, self.tokenizer, **forward_parameters)

    def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: Any) -> Dict[str, Any]:
        """Wrap the model's predictions in a result dict.

        Args:
            model_outputs: The value returned by ``_forward``. It must support
                ``len()``.
            **postprocess_parameters: Accepted for interface compatibility;
                not used.

        Returns:
            ``{"output": model_outputs, "model_length": len(model_outputs)}``.
        """
        return {"output": model_outputs, "model_length": len(model_outputs)}
|
|