|
from typing import Any, Dict, List, Optional, Union

import torch

from transformers import BertModel, BertTokenizerFast
|
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that turns strings into BERT embeddings.

    The tokenizer and model are loaded once at construction time from
    ``path_to_model``. Each call then encodes a batch of strings and
    returns one pooled vector (the model's ``pooler_output``) per string.
    """

    def __init__(
        self,
        path_to_model: str = '.',
        device: Optional[Union[str, torch.device]] = None,
    ):
        """Load the tokenizer and model.

        :param path_to_model: location passed to ``from_pretrained`` for both
            the tokenizer and the model.
        :param device: device to run the model on (e.g. ``"cpu"``,
            ``"cuda:0"``). When omitted, CUDA is used if available,
            otherwise CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.tokenizer = BertTokenizerFast.from_pretrained(path_to_model)
        # eval() switches off training-only layers (dropout) so embeddings
        # are deterministic across calls.
        self.model = BertModel.from_pretrained(path_to_model).to(self.device).eval()

    def __call__(self, data: Dict[str, Any]) -> List[List[float]]:
        """
        Encode the request's strings; called for every request to the endpoint.

        :param data: request payload; ``data["inputs"]`` is a list of strings
            (a single string is accepted and treated as a batch of one).
        :return: one embedding (list of floats) per input string, in input
            order; an empty list when ``inputs`` is an empty list.
        :raises KeyError: if the payload has no ``"inputs"`` field.
        """
        try:
            texts = data["inputs"]
        except KeyError:
            raise KeyError("request payload must contain an 'inputs' field") from None

        # Normalise a lone string to a batch of one; the tokenizer would
        # produce the same single-row batch, this just makes it explicit.
        if isinstance(texts, str):
            texts = [texts]
        # len() rather than truthiness so that None is still rejected by the
        # tokenizer instead of being silently treated as "no inputs".
        if len(texts) == 0:
            return []

        # padding=True pads to the longest sequence in the batch;
        # truncation=True clips over-long inputs to the tokenizer's maximum
        # length instead of letting them overflow the model's position table.
        encoded = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        )
        # The tokenizer builds tensors on CPU; move them to the model's device.
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            outputs = self.model(**encoded)

        # tolist() copies back to host memory, so this works on any device.
        return outputs.pooler_output.tolist()
|
|