from typing import Any, Dict, List

import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, BitsAndBytesConfig

# Marker that separates the prompt from the model's answer. The prompt built
# in ``EndpointHandler.__call__`` ends with it, and ``parse_output`` splits on it.
_RESPONSE_MARKER = "### Response:"


def parse_output(text: str) -> str:
    """Return the part of *text* that follows the response marker.

    The text is split at the first occurrence of ``"### Response:"`` and
    everything after it is returned with surrounding whitespace stripped.
    If the marker is absent, the whole text is returned (stripped).

    Special-token removal is not done here; the caller in this file decodes
    with ``skip_special_tokens=True`` before calling this function.

    NOTE(review): because the split happens at the *first* marker, an
    instruction that itself contains ``"### Response:"`` will cause part of
    the prompt to be returned as well — confirm whether that matters.
    """
    _, found, tail = text.partition(_RESPONSE_MARKER)
    return (tail if found else text).strip()


class EndpointHandler:
    """Callable handler that loads a PEFT causal LM and answers instructions.

    The model and tokenizer are loaded once in ``__init__``; each call formats
    the incoming instruction into an instruction/response prompt, runs
    ``generate`` and returns the text found after the response marker.
    """

    def __init__(self, path: str = "./", use_bnb: bool = True) -> None:
        """Load the model and tokenizer from *path*.

        Args:
            path: Location passed to ``from_pretrained`` for both the PEFT
                model and the tokenizer.
            use_bnb: When true (the default), load the model with a 4-bit
                NF4 ``BitsAndBytesConfig`` (double quantization, bfloat16
                compute dtype). When false, no quantization config is passed.
        """
        load_kwargs: Dict[str, Any] = {
            "load_in_8bit": False,
            "device_map": "auto",
        }
        if use_bnb:
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )

        self.model = AutoPeftModelForCausalLM.from_pretrained(path, **load_kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Startup diagnostics: report how the model was loaded.
        print("Memory footprint: ", self.model.get_memory_footprint())
        print("Device map: ", self.model.hf_device_map)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate a response for the instruction carried in *data*.

        Args:
            data: Request payload. ``data["inputs"]`` is used as the
                instruction; if that key is missing, *data* itself is
                formatted into the prompt. ``data["parameters"]``, when
                present and non-empty, is forwarded as keyword arguments to
                ``model.generate``.

        Returns:
            A dict with a single key, ``"generated_text"``, holding the
            decoded output after the response marker.
        """
        instruction = data.get("inputs", data)
        # ``or {}`` also covers an explicit ``"parameters": None`` in the
        # payload, which would otherwise break the ``**`` expansion below.
        parameters = data.get("parameters") or {}

        prompt = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            f"### Instruction: \n{instruction}\n\n{_RESPONSE_MARKER} \n"
        )

        with torch.no_grad():
            encoded = self.tokenizer(
                prompt, return_tensors="pt", return_token_type_ids=False
            ).to(self.model.device)
            outputs = self.model.generate(**encoded, **parameters)

        # The decoded sequence is passed whole to parse_output, which keeps
        # only what follows the response marker.
        decoded = self.tokenizer.decode(
            outputs[0].tolist(), skip_special_tokens=True
        )
        return {"generated_text": parse_output(decoded)}