from typing import Dict, List, Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Phi-3 chat template: wrap the raw user input between the user/assistant tags.
PROMPT_FORMAT = """<|user|>
{inputs}
<|end|>
<|assistant|>
"""


class EndpointHandler:
    def __init__(self, path: str = ""):
        cfg = {
            "repo": "MrOvkill/Phi-3-Instruct-Bloated",
        }
        # Load the weights in half precision; trust_remote_code is required
        # because Phi-3 ships custom modeling code.
        self.model = AutoModelForCausalLM.from_pretrained(
            cfg["repo"],
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(cfg["repo"])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs = data.pop(
            "inputs",
            "Q: What is the chemical composition of common concrete in 2024?\nA: ",
        )
        max_new_tokens = data.pop("max_new_tokens", 1024)
        try:
            max_new_tokens = int(max_new_tokens)
        except (TypeError, ValueError):
            return [{
                "status": "error",
                "reason": "max_new_tokens was passed as something that was absolutely not a plain old int",
            }]
        # Format the prompt, tokenize it, and generate greedily (no sampling).
        prompt = PROMPT_FORMAT.format(inputs=inputs)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)
        output_ids = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        # Decode only the newly generated tokens, skipping the echoed prompt.
        response = self.tokenizer.decode(
            output_ids[0][input_ids.shape[-1]:],
            skip_special_tokens=True,
        )
        return [{"generated_text": response}]
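
# A minimal local smoke test: a sketch, assuming this file is run directly and
# the machine has enough memory for the fp16 Phi-3 weights. The "inputs" and
# "max_new_tokens" keys mirror the payload an Inference Endpoints request
# would deliver to __call__; the example question here is illustrative.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({
        "inputs": "Q: What is the chemical composition of common concrete in 2024?\nA: ",
        "max_new_tokens": 128,
    })
    print(result[0]["generated_text"])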