# gpt-j-6B-fp16-sharded / handler.py
import torch
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# run the pipeline on the first GPU if one is available, otherwise on CPU
device = 0 if torch.cuda.is_available() else -1
class EndpointHandler:
    def __init__(self, path=""):
        # load the sharded checkpoint; low_cpu_mem_usage streams the shards
        # instead of materializing the full state dict, and torch_dtype keeps
        # the fp16 weights in half precision (assumes GPU inference)
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            path, low_cpu_mem_usage=True, torch_dtype=torch.float16
        )
        # create inference pipeline
        self.pipeline = pipeline(
            "text-generation", model=model, tokenizer=tokenizer, device=device
        )
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # request body: {"inputs": "...", "parameters": {...}}
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        # forward any generation kwargs (max_new_tokens, temperature, ...)
        if parameters is not None:
            prediction = self.pipeline(inputs, **parameters)
        else:
            prediction = self.pipeline(inputs)
        # list of generated sequences, e.g. [{"generated_text": "..."}]
        return prediction
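

# ---------------------------------------------------------------------------
# Usage sketch (not part of the deployed handler): Inference Endpoints
# instantiates EndpointHandler with the local model path and calls it with the
# parsed JSON request body. The "/repository" path and the sample generation
# parameters below are illustrative assumptions, not from the original file.
if __name__ == "__main__":
    handler = EndpointHandler(path="/repository")
    payload = {
        "inputs": "My name is Philipp and I",
        "parameters": {"max_new_tokens": 40, "do_sample": True, "top_k": 50},
    }
    print(handler(payload))  # -> [{"generated_text": "My name is Philipp and I ..."}]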