from typing import Dict, List, Any

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel


class EndpointHandler():
    """Inference handler that serves a PEFT (LoRA) adapter on top of its base causal LM."""

    def __init__(self, path=""):
        """
        Load the base model, its tokenizer and the adapter weights.

        Args:
            path (:obj:`str`): location of the PEFT adapter; the adapter config
                found there names the base model to load.
        """
        config = PeftConfig.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            return_dict=True,
            load_in_8bit=True,
            device_map='auto',
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # Load the Lora model
        self.model = PeftModel.from_pretrained(model, path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Generate a continuation for the given prompt.

        data args:
            prompt (:obj:`str`): text to continue; required.
            temperature (:obj:`float`, `optional`, defaults to 0.5):
            eos_token_id (:obj:`int`, `optional`, defaults to tokenizer.eos_token_id):
            early_stopping (:obj:`bool`, `optional`, defaults to `True`):
            repetition_penalty (:obj:`float`, `optional`, defaults to 0.3):
            max_new_tokens (:obj:`int`, `optional`, defaults to 100):
        Return:
            A :obj:`list` with one ``{"generated_text": str}`` dict per generated
            sequence. The decoded text contains the prompt followed by the
            continuation, because the full output sequence is decoded.
        Raises:
            ValueError: if ``data`` carries no ``prompt``.
        """
        # Get inputs (the recognised keys are removed from `data`)
        prompt = data.pop("prompt", None)
        temperature = data.pop("temperature", 0.5)
        eos_token_id = data.pop("eos_token_id", self.tokenizer.eos_token_id)
        early_stopping = data.pop('early_stopping', True)
        # NOTE(review): in transformers 1.0 means "no penalty"; a value below 1
        # rewards repetition. The 0.3 default is kept for backward compatibility
        # -- confirm it is intentional.
        repetition_penalty = data.pop('repetition_penalty', 0.3)
        max_new_tokens = data.pop('max_new_tokens', 100)

        if prompt is None:
            raise ValueError("No prompt provided.")

        # Tokenize and move the tensors to the model's device: the model is
        # placed by `device_map='auto'`, while the tokenizer returns CPU tensors.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Run prediction
        # NOTE(review): `temperature` only matters when sampling and
        # `early_stopping` only when beam search is used; neither mode is
        # enabled here, so the call behaves as before.
        prediction = self.model.generate(
            **inputs,
            temperature=temperature,
            eos_token_id=eos_token_id,
            early_stopping=early_stopping,
            repetition_penalty=repetition_penalty,
            max_new_tokens=max_new_tokens
        )

        # `generate` yields a tensor of token ids, which is neither the declared
        # return type nor JSON-serializable; decode it to text before returning.
        texts = self.tokenizer.batch_decode(prediction, skip_special_tokens=True)
        return [{"generated_text": text} for text in texts]