from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
from peft import PeftModel

class EndpointHandler:
    def __init__(self, path=""):
        # Load the tokenizer and the 4-bit quantized base model, then attach the LoRA adapter.
        base_model_name = "snorkelai/Snorkel-Mistral-PairRM-DPO"
        lora_adaptor = "mogaio/Snorkel-Mistral-PairRM-DPO-Freakonomics_MTD-TCD-Lora"

        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # NF4 4-bit quantization with double quantization keeps the base model's memory footprint small.
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            quantization_config=self.bnb_config,
            device_map="auto",
        )
        self.model.config.use_cache = False

        # Wrap the quantized base model with the fine-tuned LoRA adapter for inference.
        self.inference_model = PeftModel.from_pretrained(self.model, lora_adaptor)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        INTRO = "A chat between a curious user and a human-like artificial intelligence assistant. The assistant gives helpful, intelligent, detailed, and polite answers to the user's questions."

        # Process input: "inputs" is expected to be a list of chat turns.
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        chat_history = ' \n '.join(str(x) for x in inputs)
        prompt = INTRO + '\n ' + chat_history

        # Tokenize the prompt and move it to the available device.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_inputs = self.tokenizer(prompt + ' \n >> <assistant>:', return_tensors="pt").to(device)

        output = self.inference_model.generate(
            input_ids=model_inputs["input_ids"],
            pad_token_id=self.tokenizer.pad_token_id,
            max_new_tokens=256,
            do_sample=True, temperature=0.9, top_p=0.9, repetition_penalty=1.5,
            num_beams=5, early_stopping=True, length_penalty=-0.3,
            num_return_sequences=1,
        )

        # Decode and keep only the assistant's reply from the generated text.
        response_raw = self.tokenizer.batch_decode(output.detach().cpu().numpy(), skip_special_tokens=True)
        response_ = response_raw[0].split('>>')[1].split('<assistant>:')[1]
        response_ = response_.split('<user>')[0]
        response_ = response_.split('Instruction:')[0]
        response_ = response_.replace('\n', '')
        response = '<assistant>:' + response_.strip()
        return [{"generated_reply": response}]
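

# --- Optional local smoke test (not required by the Inference Endpoints runtime) ---
# A minimal sketch of how this handler is invoked: the runtime instantiates
# EndpointHandler once and calls it with a JSON payload whose "inputs" field holds
# the chat history as a list of turns. The example turn below is illustrative only.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {"inputs": ["<user>: What does the podcast say about incentives?"]}
    print(handler(payload)[0]["generated_reply"])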