system's picture
system HF staff
Commit From AutoTrain
a7e0a39
raw
history blame
No virus
1.47 kB
from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
import torch
class EndpointHandler:
    """Serve text generation from a PEFT adapter stacked on its base causal LM.

    The ``EndpointHandler`` name and the ``__init__(path)`` / ``__call__(data)``
    pair match the Hugging Face Inference Endpoints custom-handler convention
    (presumably how this class is invoked — confirm against the deployment).
    """

    def __init__(self, path=""):
        """Load the base model, apply the adapter at ``path``, and load the tokenizer.

        Args:
            path: Location of the PEFT adapter. Its adapter config names the
                base checkpoint (``config.base_model_name_or_path``), which is
                where both the base weights and the tokenizer are loaded from.
        """
        # The adapter config records which base checkpoint it was trained on.
        config = PeftConfig.from_pretrained(path)
        # Base weights are loaded in 8-bit with float16 as torch_dtype, and
        # placed on the available devices automatically (device_map="auto").
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            torch_dtype=torch.float16,
            load_in_8bit=True,
            device_map="auto",
        )
        self.model = PeftModel.from_pretrained(model, path)
        # The tokenizer comes from the base checkpoint, not the adapter path.
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate text for the prompt contained in ``data``.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt; if that
                key is absent, the remaining payload object itself is handed
                to the tokenizer (mirrors the common handler template).
                Optional ``data["parameters"]`` is a dict of keyword arguments
                forwarded verbatim to ``generate`` (e.g. ``max_new_tokens``).
                Both keys are removed with ``pop``, so the caller's dict is
                mutated.

        Returns:
            A one-element list ``[{"generated_text": text}]``, where ``text``
            is the first generated sequence decoded with special tokens
            skipped. For a causal LM this sequence includes the prompt tokens.
        """
        inputs = data.pop("inputs", data)
        # None and {} both mean "use the model's default generation settings".
        parameters = data.pop("parameters", None) or {}

        # Tokenize, then move the tensors to the model's device: with
        # device_map="auto" the weights are generally not on the CPU.
        encoded = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)

        # Forward the attention mask as well, so generate() does not have to
        # infer which positions are padding. token_type_ids are deliberately
        # not forwarded; many causal LMs reject them in generate().
        outputs = self.model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded.get("attention_mask"),
            **parameters,
        )

        prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return [{"generated_text": prediction}]