from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftConfig, PeftModel


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Use bfloat16 on Ampere (compute capability 8.x) or newer GPUs, float16 otherwise.
        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

        # Load the base model in 4-bit NF4 with double quantization so it fits on a single GPU.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=dtype,
        )

        # The PEFT config stored alongside the adapter points back at the base model.
        config = PeftConfig.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            return_dict=True,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

        # Attach the fine-tuned LoRA adapter weights on top of the quantized base model.
        self.model = PeftModel.from_pretrained(model, path)

    def __call__(self, data: Dict[str, Any]) -> str:
        prompt = data.pop("inputs", data)
        prompt = f"""
### Instruction: Summarize the following conversation.

### Input:
{prompt}

### Summary:
""".strip()

        input_ids = self.tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
        outputs = self.model.generate(input_ids=input_ids, max_new_tokens=200)

        # Decode the full sequence, then strip the echoed prompt so only the summary remains.
        output = self.tokenizer.batch_decode(
            outputs.detach().cpu().numpy(), skip_special_tokens=True
        )[0][len(prompt):]
        return output
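

# A minimal local smoke test, assuming the LoRA adapter weights live in the current
# directory and a CUDA GPU is available; the Inference Endpoints runtime instantiates
# EndpointHandler itself, so this guard is illustrative only and the example input is made up.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    summary = handler({"inputs": "Alice: Are we still on for lunch?\nBob: Yes, see you at noon."})
    print(summary)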