# shankz7's picture — Update handler.py — 17db25e verified
# NOTE(review): the three lines above are web-page residue from the model-card
# view; they are kept as a comment so the module parses as valid Python.
from unsloth import FastLanguageModel
from peft import PeftModel
class EndpointHandler():
    """Hugging Face Inference Endpoints custom handler.

    Loads a 4-bit quantized Mistral-7B-Instruct-v0.2 base model via Unsloth,
    attaches a LoRA adapter (PEFT) fine-tuned for customer-service entity
    extraction, and serves one text-generation request per call.
    """

    def __init__(self, path="."):
        """Load base model, tokenizer, and LoRA adapter onto the GPU.

        Args:
            path: Repository path supplied by Inference Endpoints; unused
                  here (model ids are hard-coded) but kept for interface
                  compatibility.
        """
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name="mistralai/Mistral-7B-Instruct-v0.2",  # Supports Llama, Mistral - replace this!
            max_seq_length=2048,
            dtype=None,         # let Unsloth auto-select fp16/bf16 for the GPU
            load_in_4bit=True,  # 4-bit quantization so a 7B model fits small GPUs
        )
        # Stack the fine-tuned LoRA adapter on top of the base weights.
        model = PeftModel.from_pretrained(
            model, "SITG/custsvc_entityextract_mistralv0.2instruct"
        )
        self.model = model
        self.model.eval()  # disable dropout etc. for inference
        self.device_map = "cuda"  # the device input tensors are moved onto
        self.tokenizer = tokenizer

    def __call__(self, data: dict) -> dict:
        """Handle one inference request.

        Args:
            data: Request payload. The prompt is read from ``data["inputs"]``
                  (falling back to ``data`` itself when the key is absent,
                  matching the Inference Endpoints convention).

        Returns:
            ``{"response": <cleaned model completion>}``.

        Raises:
            ValueError: If the prompt is empty.
        """
        inputs = data.pop("inputs", data)
        if not inputs:
            raise ValueError("prompt cannot be empty")
        prompt = inputs + "\n### Response:\n"
        # Tokenize once and move all tensors to the GPU here; generate()
        # then receives tensors already on-device (no second .to() needed).
        model_input = self.tokenizer(prompt, return_tensors="pt").to(self.device_map)
        output = self.model.generate(
            input_ids=model_input["input_ids"],
            # Pass the mask explicitly so generate() does not have to guess
            # it from pad tokens (and to silence the transformers warning).
            attention_mask=model_input.get("attention_mask"),
            use_cache=False,  # NOTE(review): KV cache disabled makes generation slower — confirm intentional
            temperature=0.1,
            top_k=1,   # with top_k=1, do_sample=True is effectively greedy decoding
            top_p=1.0,
            repetition_penalty=1.4,
            max_new_tokens=256,
            do_sample=True,
            # NOTE(review): Mistral tokenizers may have pad_token_id=None —
            # Unsloth usually sets one; verify on deployment.
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            num_beams=1,
            num_return_sequences=1,
        )
        decoded = self.tokenizer.decode(output[0])
        # Keep only text before the first EOS token.
        completion = decoded.split(self.tokenizer.eos_token)[0]
        # The decoded text echoes the prompt, so "Response:" should always be
        # present; partition() (unlike split()[1]) cannot raise IndexError if
        # it ever is not — we then fall back to the whole completion.
        _, marker, after_marker = completion.partition("Response:")
        result = (
            (after_marker if marker else completion)
            .strip()
            .split("###")[0]          # drop any trailing instruction block
            .replace("```json", "")   # strip markdown code fences
            .replace("```", "")
            .strip()                  # drop whitespace left by the fences
        )
        return {"response": result}