mogaio's picture
Update handler.py
db91f60 verified
raw
history blame
No virus
1.95 kB
from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
from peft import PeftModel
class EndpointHandler:
    """Hugging Face Inference Endpoints handler.

    Loads the Snorkel-Mistral-PairRM-DPO base model in 4-bit NF4 quantization
    and applies a LoRA adapter on top, then serves text generation requests.
    """

    def __init__(self, path=""):
        # Load tokenizer + 4-bit quantized base model, then attach the LoRA adapter.
        base_model_name = "snorkelai/Snorkel-Mistral-PairRM-DPO"
        lora_adaptor = "mogaio/Snorkel-Mistral-PairRM-DPO-Freakonomics_MTD-TCD-Lora"
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        # Mistral ships no dedicated pad token; reuse EOS so padding is well-defined.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            quantization_config=self.bnb_config,
            device_map="auto",
        )
        # KV-cache off was presumably a training-time setting carried over;
        # it is kept here to preserve existing behavior.
        self.model.config.use_cache = False
        # NOTE(review): `from_transformers` is not a documented
        # PeftModel.from_pretrained parameter — confirm it is needed; it looks
        # like it is silently swallowed by **kwargs.
        self.inference_model = PeftModel.from_pretrained(
            self.model, lora_adaptor, from_transformers=True
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate an assistant reply for the prompt in ``data["inputs"]``.

        Args:
            data: request payload; the prompt text is expected under the
                standard Inference Endpoints key ``"inputs"``.

        Returns:
            A one-element list: ``[{"generated_reply": "<assistant>: ..."}]``.
        """
        # BUG FIX: the original body referenced an undefined variable `inputs`
        # (the payload was never tokenized), raising NameError on every call.
        prompt = data.get("inputs", "")
        inputs = self.tokenizer(prompt, return_tensors="pt").to(
            self.inference_model.device
        )
        output = self.inference_model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            pad_token_id=self.tokenizer.pad_token_id,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.9,
            top_p=0.9,
            repetition_penalty=1.5,
            early_stopping=True,
            length_penalty=-0.3,
            num_beams=5,
            num_return_sequences=1,
        )
        response_raw = self.tokenizer.batch_decode(
            output.detach().cpu().numpy(), skip_special_tokens=True
        )
        # Post-processing assumes the prompt template delimits turns with '>>'
        # and tags speakers with '<assistant>:'/'<user>' — TODO confirm against
        # the caller's template; an IndexError here means the model output did
        # not contain the expected markers.
        response_ls = response_raw[0].split('>>')
        response_ = response_ls[1].split('<assistant>:')[1]
        response_ = response_.split('<user>')[0]
        response_ = response_.split('Instruction:')[0]
        response_ = response_.replace('\n', '')
        response = '<assistant>:' + response_.strip()
        return [{"generated_reply": response}]