import logging
from typing import Any, Dict, List

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

device = "cuda" if torch.cuda.is_available() else "cpu"
class EndpointHandler:
    def __init__(self, path: str = ""):
        # Resolve the base model from the PEFT adapter config stored at `path`.
        config = PeftConfig.from_pretrained(path)
        # Load the base model in 8-bit (requires bitsandbytes and a CUDA GPU).
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            trust_remote_code=True,
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.base_model_name_or_path, trust_remote_code=True
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # Load the LoRA adapter weights on top of the quantized base model.
        self.model = PeftModel.from_pretrained(model, path, torch_dtype=model.dtype)
        # Make sure generation stops at the tokenizer's EOS token.
        self.model.generation_config.eos_token_id = self.tokenizer.eos_token_id
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run generation for one request.

        Args:
            data (Dict): Payload with the text prompt under "inputs" and
                optional generation parameters under "parameters".

        Returns:
            A list containing a single {"generated_text": ...} dict.
        """
LOGGER.info(f"Received data: {data}")
# Get inputs
prompt = data.pop("inputs", None)
parameters = data.pop("parameters", None)
if prompt is None:
raise ValueError("Missing prompt.")
# Preprocess
encoding = self.tokenizer(
prompt, return_tensors="pt")
input_ids = encoding.input_ids.to(device)
attention_mask = encoding.attention_mask
        # Forward
        LOGGER.info("Starting generation.")
        if parameters is not None:
            LOGGER.info("Generation parameters were provided; passing them to generate().")
            output = self.model.generate(
                input_ids=input_ids, attention_mask=attention_mask, **parameters
            )
        else:
            LOGGER.info("No generation parameters provided; using defaults.")
            output = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=256,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Postprocess: decode the full sequence, dropping special tokens.
        prediction = self.tokenizer.decode(output[0], skip_special_tokens=True)
        LOGGER.info(f"Generated text: {prediction}")
        return [{"generated_text": prediction}]
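

if __name__ == "__main__":
    # Minimal local smoke test, a sketch only (not part of the Inference
    # Endpoints contract): it assumes `path` points at a PEFT adapter repo id
    # or local directory whose adapter config names the base model. The "."
    # path below is a hypothetical placeholder for such a directory.
    handler = EndpointHandler(path=".")
    response = handler(
        {
            "inputs": "Question: What is PEFT?\nAnswer:",
            "parameters": {"max_new_tokens": 64, "do_sample": False},
        }
    )
    print(response[0]["generated_text"])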