# MultiTrickFox
# Commit: "Update handler.py" (bf2292e, verified)
from typing import Dict, List, Any
import time
import torch
from transformers import AutoTokenizer, AutoModel
#
class EndpointHandler:
    """Custom inference-endpoint handler that performs text generation.

    The tokenizer and model are loaded once when the handler is constructed.
    Each call tokenizes the request payload, runs ``model.generate`` and
    returns the decoded output together with the elapsed generation time.
    """

    def __init__(self, path=''):
        """Load the tokenizer and an 8-bit model from *path*.

        Args:
            path: Directory (or hub id) handed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # NOTE(review): for many architectures ``AutoModel`` resolves to the
        # base model class without a language-modeling head, which may not be
        # able to ``generate``; ``AutoModelForCausalLM`` (or
        # ``AutoModelForSeq2SeqLM``) may be what is intended — confirm against
        # the deployed checkpoint. ``load_in_8bit`` needs bitsandbytes + a GPU;
        # newer transformers prefer
        # ``quantization_config=BitsAndBytesConfig(load_in_8bit=True)``.
        self.model = AutoModel.from_pretrained(path, load_in_8bit=True)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate text for the request payload.

        Args:
            data: Request payload. ``data['inputs']`` is the text (or list of
                texts) to tokenize; if the key is missing, the payload itself
                is used. Optional ``data['parameters']`` is forwarded as
                keyword arguments to ``model.generate``. Both keys are popped,
                so *data* is mutated.

        Returns:
            A one-element list holding a dict with ``'generated_text'`` (the
            list returned by ``tokenizer.batch_decode``, one string per
            sequence, special tokens included) and ``'generation_time'``
            (elapsed seconds for tokenize + generate + decode).
        """
        inputs = data.pop('inputs', data)
        # ``or {}`` also covers an explicit ``"parameters": null`` in the
        # payload, which would otherwise crash on ``**None`` below.
        parameters = data.pop('parameters', None) or {}

        # perf_counter is monotonic and meant for measuring intervals.
        starting_time = time.perf_counter()
        # Move the encodings to wherever the model was placed (8-bit loading
        # decides placement), rather than assuming the 'cuda' default device.
        tokenized = self.tokenizer(inputs, return_tensors='pt').to(self.model.device)
        # ``generate`` expects a tensor as its first argument, not the whole
        # BatchEncoding; pass ids and mask explicitly by keyword. Extra
        # encoding keys (e.g. token_type_ids) are deliberately not forwarded,
        # since some models reject unused model kwargs.
        out = self.model.generate(
            input_ids=tokenized['input_ids'],
            attention_mask=tokenized.get('attention_mask'),
            **parameters,
        ).to('cpu')
        detokenized = self.tokenizer.batch_decode(out)
        ending_time = time.perf_counter()

        return [{'generated_text': detokenized, 'generation_time': ending_time - starting_time}]