handler.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from typing import Dict, List, Any

class EndpointHandler():
    def __init__(self, path=""):
        # 4-bit NF4 quantization with double quantization to cut GPU memory use;
        # compute is carried out in float16.
        self.quant = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True)
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", quantization_config=self.quant, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Text-generation pipeline wrapping the quantized model and its tokenizer.
        self.pipeline = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)

    def __call__(self, data: Dict[str, Dict[str, Any]]) -> Any:
        # The request payload carries the prompt under data["inputs"]["msg"]
        # and the generation keyword arguments under data["args"].
        inputs = data["inputs"]["msg"]
        parameters = data["args"]
        prediction = self.pipeline(inputs, **parameters)
        # Return only the generated text of the first returned sequence.
        return prediction[0]["generated_text"]
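

# --- Usage sketch (not part of the original handler) -------------------------
# A minimal local smoke test of the handler, assuming a causal LM repository is
# available at MODEL_PATH. The path and the generation arguments below are
# illustrative assumptions, not values taken from the original file.
if __name__ == "__main__":
    MODEL_PATH = "path/to/model"  # hypothetical local directory or Hub repo id
    handler = EndpointHandler(path=MODEL_PATH)
    payload = {
        "inputs": {"msg": "Summarize what 4-bit NF4 quantization does."},
        "args": {"max_new_tokens": 64, "do_sample": True, "temperature": 0.7},
    }
    # Prints the generated text of the first returned sequence.
    print(handler(payload))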