flan-ul2-20b-fp16 / handler.py
philschmid's picture
philschmid HF staff
Create handler.py
002f70f
raw
history blame contribute delete
No virus
1.17 kB
from typing import Dict, List, Any
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
class EndpointHandler:
    """Callable handler wrapping a seq2seq model and its tokenizer for text generation."""

    def __init__(self, path=""):
        """Load the model and tokenizer stored at ``path``.

        Args:
            path: Directory (or model identifier) passed to ``from_pretrained``
                for both the model and the tokenizer.
        """
        # device_map="auto" lets the loader decide weight placement and
        # load_in_8bit=True requests 8-bit quantized weights.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            path, device_map="auto", load_in_8bit=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate text for a request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is the text handed to the
                tokenizer (when the key is missing, the payload itself is used,
                as before). ``data["parameters"]`` is an optional mapping of
                keyword arguments forwarded to ``model.generate``. Both keys
                are popped from ``data``.

        Returns:
            A single-element list ``[{"generated_text": <decoded text>}]``
            built from the first generated sequence.
        """
        # process input
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # preprocess: tokenize, then place the ids on the model's device so
        # generate() does not receive tensors on a different device than the
        # weights.
        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
        input_ids = input_ids.to(self.model.device)

        # A missing/empty "parameters" entry means "use generation defaults".
        outputs = self.model.generate(input_ids, **(parameters or {}))

        # postprocess the prediction (only the first sequence is returned)
        prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return [{"generated_text": prediction}]