import json
import os
from typing import Dict, List, Any

import torch
from transformers import pipeline

# Hub repository the generation pipeline is loaded from.
MODEL_REPO = "MrOvkill/Phi-3-Instruct-Bloated"

# Token budget used when the request does not supply "max_new_tokens".
DEFAULT_MAX_NEW_TOKENS = 1024

# Chat-style prompt: the caller's text is wrapped in <|user|> ... <|end|> tags and
# generation continues after the <|assistant|> tag.
PROMPT_FORMAT = """
<|user|>
{inputs}
<|end|>
<|assistant|>
"""


class EndpointHandler:
    """Inference-endpoint handler that runs greedy text generation on one prompt."""

    def __init__(self, data):
        """Build the text-generation pipeline once, when the handler is created.

        Args:
            data: Supplied by the hosting runtime (presumably the local model
                path -- confirm). It is currently ignored; the model is always
                loaded from ``MODEL_REPO``.
        """
        # NOTE(review): trust_remote_code=True executes Python code shipped in
        # the hub repository; only keep this if that repo is trusted.
        self.pipe = pipeline(
            "text-generation",
            MODEL_REPO,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` (str) is the user text,
                inserted into ``PROMPT_FORMAT``. Optional
                ``data["max_new_tokens"]`` (int or int-convertible) caps the
                generated length; defaults to ``DEFAULT_MAX_NEW_TOKENS``.

        Returns:
            Whatever the text-generation pipeline returns for the prompt
            (decoding is greedy: ``do_sample=False``). If ``max_new_tokens``
            cannot be converted to ``int``, a JSON *string* describing the error
            is returned instead.

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` key.
        """
        # Reuse the pipeline built in __init__; rebuilding it here reloaded the
        # whole model on every request.
        max_new_tokens = data.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
        try:
            max_new_tokens = int(max_new_tokens)
        except (TypeError, ValueError, OverflowError):
            # NOTE(review): kept as a JSON string for wire compatibility with
            # existing clients; a plain dict would avoid double-encoding.
            return json.dumps({
                "status": "error",
                "reason": "max_new_tokens was passed as something that was absolutely not a plain old int",
            })

        prompt = PROMPT_FORMAT.format(inputs=data["inputs"])
        return self.pipe(
            prompt,
            do_sample=False,
            max_new_tokens=max_new_tokens,
        )