import torch
from typing import Any, Dict

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# bfloat16 is supported on compute capability 8.0 (Ampere) and newer; fall
# back to float16 on older GPUs, and avoid querying CUDA when none is present.
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)


class EndpointHandler:
    def __init__(self, model_path: str = ""):
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            return_dict=True,
            device_map="auto",
            load_in_8bit=True,  # 8-bit quantization; requires bitsandbytes
            torch_dtype=dtype,
            trust_remote_code=True,
        )
        self.pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            do_sample=True,  # temperature has no effect under greedy decoding
            temperature=0.8,
            repetition_penalty=1.1,
            max_new_tokens=1000,
            # Many causal LM tokenizers define no pad token; fall back to EOS.
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    def __call__(self, data: Dict[str, Any]) -> str:
        # Inference Endpoints sends the request body as a dict with an
        # "inputs" key; fall back to the raw payload if the key is absent.
        prompt = data.pop("inputs", data)
        llm_response = self.pipeline(prompt, return_full_text=False)
        return llm_response[0]["generated_text"].strip()
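
# A minimal local smoke test, sketched for convenience; it is not part of the
# Inference Endpoints contract. "my-org/my-model" is a hypothetical model id,
# so substitute a real causal LM repository before running.
if __name__ == "__main__":
    handler = EndpointHandler(model_path="my-org/my-model")
    print(handler({"inputs": "Explain 8-bit quantization in one sentence."}))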