from typing import Any, Dict, List

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


class EndpointHandler:
def __init__(self, path=""):
# load model and processor from path
base_model_name = "snorkelai/Snorkel-Mistral-PairRM-DPO"
lora_adaptor = "mogaio/Snorkel-Mistral-PairRM-DPO-Freakonomics_MTD-TCD-Lora"
self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
self.tokenizer.pad_token = self.tokenizer.eos_token
        # 4-bit NF4 quantization with double quantization; compute in bfloat16.
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            quantization_config=self.bnb_config,
            device_map="auto",  # Automatically place model weights on the available device(s).
        )
        self.model.config.use_cache = False
        # `from_transformers` is not a PeftModel.from_pretrained argument; drop it.
        self.inference_model = PeftModel.from_pretrained(self.model, lora_adaptor)
def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
INTRO = "Below is a conversation between a user and you."
END = "Instruction: Write a response appropriate to the conversation."
prompt = "<user>:"
# process input
inputs = data.pop("inputs", data)
parameters = data.pop("parameters", None)
prompt = prompt+inputs
        # Tokenize the assembled prompt and move it to the available device
        # (the original hard-coded a second move to "cuda", which fails on CPU-only hosts).
        device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = self.tokenizer(
            INTRO + '\n ' + prompt + '\n ' + END + '\n <assistant>:',
            return_tensors="pt",
        ).to(device)
        # Generate a short, low-temperature reply and decode it.
        output = self.inference_model.generate(
            input_ids=inputs["input_ids"],
            pad_token_id=self.tokenizer.pad_token_id,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.1,
            top_p=0.9,
            repetition_penalty=1.5,
        )
        reply = self.tokenizer.batch_decode(output, skip_special_tokens=True)
        return [{"generated_reply": reply[0]}]