llm-tolkien / handler.py
from typing import Dict, Any
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel

LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

device = "cuda" if torch.cuda.is_available() else "cpu"
class EndpointHandler:
    def __init__(self, path=""):
        config = PeftConfig.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path, load_in_8bit=True, device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # Load the LoRA adapter on top of the 8-bit base model
        self.model = PeftModel.from_pretrained(model, path)
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data (Dict): The payload with the text prompt and generation parameters.

        Returns:
            Dict: The generated text under the "generated_text" key.
        """
        LOGGER.info(f"Received data: {data}")
        # Get inputs
        prompt = data.pop("inputs", None)
        parameters = data.pop("parameters", None)
        if prompt is None:
            raise ValueError("Missing prompt.")
        # Preprocess: tokenize the prompt and move it to the model's device
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        # Forward: generate, forwarding any user-supplied generation parameters
        LOGGER.info("Start generation.")
        if parameters is not None:
            output = self.model.generate(input_ids=input_ids, **parameters)
        else:
            output = self.model.generate(input_ids=input_ids)
        # Postprocess: decode the generated token ids back to text
        prediction = self.tokenizer.decode(output[0])
        LOGGER.info(f"Generated text: {prediction}")
        return {"generated_text": prediction}