from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the model in half precision and request hidden states.
        self.model = AutoModelForCausalLM.from_pretrained(
            "gpt2",
            torch_dtype=torch.float16,
            output_hidden_states=True,
        )
        self.model = self.model.cuda()
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List]:
        """
        Args:
            data (:obj:`dict`): includes the input documents under the "inputs" key.
        Return:
            A :obj:`dict` mapping "logits" to a list of mask-averaged logit
            vectors, one entry per input document.
        """
        # Process input: "inputs" is expected to be a list of text documents.
        inputs = data.pop("inputs", data)

        all_logits = []
        for doc in inputs:
            # Tokenize one document at a time, truncating to 512 tokens.
            tokenized = self.tokenizer(
                doc,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            )
            token_ids = tokenized.input_ids.cuda()
            token_mask = tokenized.attention_mask.cuda()

            with torch.no_grad():
                out = self.model(token_ids, attention_mask=token_mask)

            # Average the per-token logits over non-padding positions only,
            # using the attention mask as weights.
            meaned_logits = (out.logits * token_mask.unsqueeze(-1)).sum(1) / token_mask.sum(
                1
            ).unsqueeze(-1)

            all_logits.append(meaned_logits.cpu().numpy().tolist())

        # Postprocess the prediction.
        return {"logits": all_logits}
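
# A minimal local usage sketch, not part of the handler contract itself. It
# assumes a CUDA-capable GPU is available (the handler moves the model and
# tensors to .cuda()) and that the payload carries a list of strings under
# "inputs", matching what __call__ above iterates over. The example texts and
# the printed shape check are illustrative only.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": [
            "The quick brown fox jumps over the lazy dog.",
            "Transformers make feature extraction straightforward.",
        ]
    }
    result = handler(payload)
    # One mask-averaged logit vector per document; each vector has vocab size
    # (50257 for gpt2).
    print(len(result["logits"]), len(result["logits"][0][0]))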