# Retrieved from the Hugging Face Hub file view of handler.py
# Latest commit: "Update handler.py" (0821eba), author avatar: Ngit
# Hub malware scan: no virus detected; file size: 1.76 kB
import torch
from typing import Dict, List, Any
from transformers import AutoTokenizer, BitsAndBytesConfig
from peft import AutoPeftModelForCausalLM
def parse_output(text):
    """Extract the answer portion from a decoded generation string.

    Everything up to and including the first ``"### Response:"`` marker is
    dropped; when the marker does not occur, the whole string is kept. Literal
    ``<pad>`` and ``</s>`` substrings are then removed (in that order) and
    leading/trailing whitespace is trimmed.
    """
    _, separator, answer = text.partition("### Response:")
    if not separator:
        # Marker absent: partition puts the whole string in the head, so fall back to it.
        answer = text
    for leftover in ("<pad>", "</s>"):
        answer = answer.replace(leftover, "")
    return answer.strip()
class EndpointHandler:
    """Inference handler that serves a PEFT causal-LM adapter with an Alpaca-style prompt."""

    def __init__(self, path="./", use_bnb=True):
        """Load the PEFT model and its tokenizer from *path*.

        Args:
            path: Location of the adapter/tokenizer files, passed to both
                ``AutoPeftModelForCausalLM.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``.
            use_bnb: When True (the default), load the model in 4-bit NF4 with
                double quantization and bfloat16 compute via bitsandbytes.
                When False, no quantization config is passed, so the model
                loads with the library's default dtype (presumably full
                precision; confirm memory fits the target hardware).
        """
        load_kwargs = {"device_map": "auto"}
        if use_bnb:
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        self.model = AutoPeftModelForCausalLM.from_pretrained(path, **load_kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Startup diagnostics: show how much memory the model takes and where it was placed.
        print("Memory footprint: ", self.model.get_memory_footprint())
        print("Device map: ", self.model.hf_device_map)

    def __call__(self, data: Any) -> Dict[str, str]:
        """Generate a completion for a single request.

        Args:
            data: Request payload. Normally a dict whose ``"inputs"`` entry is
                the instruction text and whose optional ``"parameters"`` dict
                is forwarded as keyword arguments to ``model.generate``. A dict
                without ``"inputs"`` is itself interpolated as the instruction,
                as before. A plain string payload is treated as the instruction
                with no generation parameters.

        Returns:
            ``{"generated_text": answer}``, where *answer* contains only the
            newly generated text, cleaned by ``parse_output``.
        """
        if isinstance(data, str):
            instruction = data
            parameters = {}
        else:
            instruction = data.get("inputs", data)
            # ``"parameters": null`` in the JSON body would otherwise fail at ``**parameters``.
            parameters = data.get("parameters") or {}

        prompt = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction: \n{instruction}\n\n### Response: \n"
        encoded = self.tokenizer(
            prompt, return_tensors="pt", return_token_type_ids=False
        ).to(self.model.device)
        prompt_length = encoded["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(**encoded, **parameters)

        # For a causal LM the output sequence begins with the prompt tokens. Decode only
        # the tokens after the prompt so that an instruction that itself contains
        # "### Response:" cannot make prompt text leak into the answer.
        new_tokens = outputs[0][prompt_length:]
        completion = self.tokenizer.decode(new_tokens.tolist(), skip_special_tokens=True)
        return {"generated_text": parse_output(completion)}