from typing import Any, Dict

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

class EndpointHandler:
    def __init__(self, path=""):
        # Resolve the base model from the adapter's PEFT config
        config = PeftConfig.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # Load the LoRA adapter on top of the base model and move it to the
        # same device the inputs will be placed on
        self.model = PeftModel.from_pretrained(model, path).to(device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data (Dict): The payload with the text prompt
                and generation parameters.
        Returns:
            Dict: The generated text under the "generated_text" key.
        """
        # Get inputs
        prompt = data.pop("inputs", None)
        parameters = data.pop("parameters", None)
        if prompt is None:
            raise ValueError("Missing prompt.")
        # Preprocess: tokenize the prompt and move it to the model's device
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        # Forward: pass any user-supplied generation parameters through
        if parameters is not None:
            output = self.model.generate(input_ids=input_ids, **parameters)
        else:
            output = self.model.generate(input_ids=input_ids)
        # Postprocess: decode the first (and only) generated sequence
        prediction = self.tokenizer.decode(output[0])
        return {"generated_text": prediction}