from typing import Any, Dict, List

from optimum.intel import OVModelForSeq2SeqLM
from transformers import AutoTokenizer

INSTRUCTION = "rewrite: "

generation_config = {
    "max_new_tokens": 16,
    "use_cache": True,
    "temperature": 0.6,
    "do_sample": True,
    "top_p": 0.95,
}


class EndpointHandler:
    def __init__(self, path="."):
        # Preload the model and tokenizer once so every request reuses them.
        self.model = OVModelForSeq2SeqLM.from_pretrained(
            path, use_cache=True, use_io_binding=False
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        """
        Args:
            data: a dict with an "inputs" key holding the text to rewrite,
                and an optional "parameters" key overriding the default
                generation config.

        Returns:
            A list of generated strings, which is serialized and returned
            by the endpoint.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", generation_config)

        # Prepend the task instruction (which already ends with a space)
        # and tokenize; the prompt is truncated to 20 tokens.
        tokenized = self.tokenizer(
            [f"{INSTRUCTION}{inputs}"],
            padding=False,
            return_tensors="pt",
            max_length=20,
            truncation=True,
        )
        outputs = self.model.generate(**tokenized, **parameters)
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
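

# --- Usage sketch (illustrative only) ---
# A minimal local smoke test, assuming an exported OpenVINO model and its
# tokenizer live in the current directory. The payload shape follows the
# Hugging Face Inference Endpoints custom-handler convention; the example
# input text is hypothetical.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    result = handler({"inputs": "teh quick brwn fox jumps over teh lazy dog"})
    print(result)  # e.g. ["the quick brown fox jumps over the lazy dog"]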