from typing import Any, Dict

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Use bfloat16 on Ampere-or-newer GPUs (compute capability >= 8), float16 otherwise.
        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

        # 4-bit NF4 quantization with double quantization to reduce memory use.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=dtype,
        )

        # The adapter config records which base model the PEFT weights were trained on.
        config = PeftConfig.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            return_dict=True,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

        # Attach the PEFT (LoRA) adapter weights stored at `path` to the quantized base model.
        self.model = PeftModel.from_pretrained(model, path)
    def __call__(self, data: Dict[str, Any]) -> str:
        # Inference Endpoints pass the request body as {"inputs": ...}.
        prompt = data.pop("inputs", data)
        prompt = f"""
### Instruction:
Summarize the following conversation.

### Input:
{prompt}

### Summary:
""".strip()

        # Tokenize and move the prompt to the same device as the model,
        # rather than assuming CUDA is available.
        input_ids = self.tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.to(self.model.device)
        outputs = self.model.generate(input_ids=input_ids, max_new_tokens=200)

        # Decode the full sequence, then drop the echoed prompt so only the summary remains.
        output = self.tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]
        return output
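

# --- Hedged usage sketch, not part of the handler contract ---
# A minimal local smoke test, assuming this directory holds the PEFT adapter and
# a CUDA GPU is available. The adapter path "." and the sample conversation are
# illustrative placeholders, not values from the original handler.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # hypothetical: adapter in the current directory
    sample = {"inputs": "A: Are we still on for lunch?\nB: Yes, noon at the usual place."}
    print(handler(sample))  # prints the generated summary text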