# llm_deploy_small/handler.py
from typing import Any, Dict

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the 8-bit base model named in the adapter config, then attach the LoRA weights.
        config = PeftConfig.from_pretrained("JeremyArancio/llm-tolkien")
        self.model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, device_map="auto")
        self.model = PeftModel.from_pretrained(self.model, "JeremyArancio/llm-tolkien")
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run generation for one request.

        Args:
            data (Dict): Payload with the text prompt under "inputs" and
                optional generation parameters under "parameters".
        """
        # Get inputs
        prompt = data.pop("inputs", None)
        parameters = data.pop("parameters", None)
        if prompt is None:
            raise ValueError("Missing prompt.")
        # Preprocess: tokenize the prompt and move it to the model's device
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        # Forward: generate, passing through any user-supplied parameters
        if parameters is not None:
            output = self.model.generate(input_ids=input_ids, **parameters)
        else:
            output = self.model.generate(input_ids=input_ids)
        # Postprocess: decode the generated token ids back to text
        prediction = self.tokenizer.decode(output[0])
        return {"generated_text": prediction}