from typing import Dict, List, Any

import torch
from transformers import BertModel, BertTokenizerFast


class EndpointHandler():
    """Custom inference handler that encodes text with a BERT model.

    The tokenizer and model are loaded once at construction time and reused
    for every request handled by ``__call__``.
    """

    def __init__(self, path_to_model: str = '.'):
        """Load the tokenizer and model stored under ``path_to_model``.

        :param path_to_model: directory (or identifier) passed to
            ``from_pretrained`` for both the tokenizer and the model.
        """
        self.tokenizer = BertTokenizerFast.from_pretrained(path_to_model)
        self.model = BertModel.from_pretrained(path_to_model)
        # Switch to evaluation mode (e.g. disables dropout) for inference.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[List[float]]:
        """Encode the request's input strings and return their pooled outputs.

        This method is called whenever a request is made to the endpoint.

        :param data: request payload; ``data['inputs']`` holds the string or
            list of strings to be encoded.
        :return: the model's ``pooler_output`` converted to nested Python
            lists (one row per input) so that it can be serialized.
        :raises KeyError: if ``data`` has no ``'inputs'`` entry.
        """
        texts = data['inputs']

        # An empty batch has nothing to encode; skip the tokenizer and model.
        if isinstance(texts, (list, tuple)) and len(texts) == 0:
            return []

        # truncation=True clips over-length inputs to the tokenizer's
        # configured maximum instead of letting them fail inside the model.
        encoded = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        # Gradients are not needed at inference time.
        with torch.no_grad():
            outputs = self.model(**encoded)

        # A torch.Tensor is not serializable as-is; return plain lists.
        return outputs.pooler_output.tolist()