Andrewwwwww's picture
Update handler.py
163e45f verified
raw
history blame contribute delete
No virus
1.28 kB
from typing import Dict, List, Any
import torch
from modelscope import AutoTokenizer
from modelscope import AutoModelForCausalLM
# Device the model is loaded onto. Prefer the GPU, but fall back to the CPU so
# the handler can still start on a host where CUDA is not available.
device = "cuda" if torch.cuda.is_available() else "cpu"
class EndpointHandler:
    """Callable handler that runs chat-style text generation.

    A causal language model and its tokenizer are loaded once at construction
    time. Each call turns the request payload into a single prompt string
    delimited by ``<|im_start|>`` / ``<|im_end|>`` markers, samples a
    continuation from the model and returns the decoded text.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer from ``path``.

        Args:
            path: Location passed unchanged to both ``from_pretrained`` calls.

        The model is loaded in bfloat16 and placed through ``device_map``
        using the module-level ``device`` setting.
        """
        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype=torch.bfloat16, device_map=device
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Any) -> str:
        """Generate a reply for one chat request.

        Args:
            data: Mapping with two required keys:
                ``"prompt"`` - the system prompt text (a ``"."`` is appended
                to it inside the system turn);
                ``"inputs"`` - an iterable of messages, each a mapping with
                ``"role"`` and ``"content"`` keys.

        Returns:
            The text obtained by decoding the first sequence returned by
            ``generate``. No slicing or special-token stripping is applied,
            so callers receive exactly what the tokenizer decodes.

        Raises:
            KeyError: If ``"prompt"``, ``"inputs"``, ``"role"`` or
                ``"content"`` is missing from a plain-dict payload.
        """
        sys_prompt = data["prompt"]
        messages = data["inputs"]

        # Collect the turns and join once instead of growing a string with +=.
        parts = [f"<|im_start|>system\n{sys_prompt}.<|im_end|>\n"]
        for message in messages:
            # Any role other than "assistant" is rendered as a user turn.
            role = "assistant" if message["role"] == "assistant" else "user"
            content = message["content"]
            parts.append(f"<|im_start|>{role}\n{content}<|im_end|>\n")
        # Open an assistant turn so the model continues with its reply.
        parts.append("<|im_start|>assistant\n")
        prompt = "".join(parts)

        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        # The model was already placed by ``device_map`` in ``__init__``, so it
        # is not moved again per request; only the inputs are moved, to
        # whichever device the model reports.
        input_ids = input_ids.to(self.model.device)

        generated_ids = self.model.generate(
            input_ids, max_new_tokens=1000, do_sample=True
        )
        return self.tokenizer.decode(generated_ids[0])