from typing import List, Optional, Union

from infinity.tasks import (
    TextClassificationEndpoint,
    TextClassificationOutput,
    TextClassificationParams,
)
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer, Pipeline, pipeline
class BankingEndpoint(TextClassificationEndpoint):
    """Text-classification endpoint for the banking77 DistilBERT ONNX checkpoint.

    Construction is cheap: the model, tokenizer and pipeline are only loaded
    when ``initialize`` is called. ``handle`` requires ``initialize`` to have
    run first.
    """

    __slots__ = ("_pipeline",)

    def __init__(self):
        """Create the endpoint with no pipeline loaded yet."""
        super().__init__()
        # Stores the object returned by transformers.pipeline(...), not the
        # bare ORT model; stays None until initialize() has completed.
        self._pipeline: Optional[Pipeline] = None

    def initialize(self, **kwargs):
        """Load the model and tokenizer and assemble the pipeline.

        Args:
            **kwargs: Accepted for interface compatibility; not read here.
        """
        # One identifier for both loads so the model and its tokenizer
        # cannot drift apart.
        model_id = "philschmid/distilbert-onnx-banking77"

        print("Initializing")
        model = ORTModelForSequenceClassification.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        self._pipeline = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
        )
        print("INITIALIZED")

    def handle(
        self,
        inputs: Union[str, List[str]],
        parameters: TextClassificationParams,
    ) -> List[TextClassificationOutput]:
        """Run the loaded pipeline on ``inputs``.

        Args:
            inputs: A single text or a list of texts to classify.
            parameters: Unpacked as keyword arguments into the pipeline call
                (assumes a mapping-like object -- TODO confirm against the
                ``TextClassificationParams`` definition).

        Returns:
            Whatever the pipeline call returns for ``inputs``.

        Raises:
            RuntimeError: If ``initialize`` has not been called yet.
        """
        if self._pipeline is None:
            # Fail with an explicit cause instead of the opaque
            # "'NoneType' object is not callable" TypeError.
            raise RuntimeError(
                "BankingEndpoint.handle() called before initialize(); "
                "the pipeline has not been loaded."
            )
        return self._pipeline(inputs, **parameters)