from typing import Dict, List, Any
import logging

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel

LOGGER = logging.getLogger(__name__)


class EndpointHandler:
    """Inference handler that serves a PEFT (LoRA) adapter on top of its base causal LM.

    The adapter at ``path`` is loaded with ``PeftConfig``/``PeftModel``; the base
    model named in that config is loaded in 8-bit with an automatic device map.
    """

    def __init__(self, path: str = "") -> None:
        """Load the base model, its tokenizer, and the LoRA adapter.

        Args:
            path: Location of the PEFT adapter (its config supplies
                ``base_model_name_or_path`` for the base model and tokenizer).
        """
        config = PeftConfig.from_pretrained(path)
        # NOTE(review): newer transformers releases deprecate ``load_in_8bit`` in
        # favour of ``quantization_config=BitsAndBytesConfig(load_in_8bit=True)``.
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            load_in_8bit=True,
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # Wrap the base model with the LoRA adapter weights stored at ``path``.
        self.model = PeftModel.from_pretrained(model, path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a completion for the prompt contained in ``data``.

        Args:
            data: The payload with the text prompt and generation parameters.
                ``data["inputs"]`` is the prompt; if that key is absent the whole
                payload is used as the input. The optional ``data["parameters"]``
                dict is forwarded as keyword arguments to ``generate``. Both keys
                are popped, so the caller's dict is mutated.

        Returns:
            ``{"generated_text": <decoded first generated sequence>}``.
        """
        LOGGER.info("Received data: %s", data)

        # Get inputs
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        LOGGER.info("Data extracted.")

        # Preprocess
        LOGGER.info("Start tokenizer: %s", inputs)
        encoded = self.tokenizer(inputs, return_tensors="pt")
        # The tokenizer returns CPU tensors while device_map="auto" places the
        # model on its own device(s); move the inputs to where the model expects them.
        encoded = encoded.to(self.model.device)

        # Forward
        LOGGER.info("Start generation.")
        # Tensors are passed by keyword because some PEFT versions define
        # PeftModelForCausalLM.generate(**kwargs) and reject positional inputs.
        outputs = self.model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded.get("attention_mask"),
            **(parameters or {}),
        )

        # Postprocess
        prediction = self.tokenizer.decode(outputs[0])
        LOGGER.info("Generated text: %s", prediction)
        return {"generated_text": prediction}