|
import json |
|
import os |
|
from typing import Dict, List, Any |
|
import torch |
|
from transformers import pipeline |
|
|
|
# Prompt template applied to each request: ``str.format`` substitutes the
# caller's text for ``{inputs}``, which sits between the ``<|user|>`` and
# ``<|end|>`` markers and is followed by an ``<|assistant|>`` line.
# NOTE(review): presumably the Phi-3 chat markup -- confirm against the model
# card before changing any whitespace or marker here.
PROMPT_FORMAT= """
<|user|>
{inputs} <|end|>
<|assistant|>
"""
|
|
|
class EndpointHandler:
    """Request handler that wraps a ``text-generation`` pipeline.

    The pipeline is built once in ``__init__`` and reused by every call, so
    the model is not reloaded per request.
    """

    # Model repository the pipeline is loaded from.
    MODEL_REPO = "MrOvkill/Phi-3-Instruct-Bloated"
    # Generation budget used when a request does not set "max_new_tokens".
    DEFAULT_MAX_NEW_TOKENS = 1024

    def __init__(self, data):
        """Build the text-generation pipeline.

        Args:
            data: Accepted for interface compatibility; not used here.
        """
        # NOTE(security): trust_remote_code=True lets the model repository
        # run its own Python code at load time; keep MODEL_REPO pinned to a
        # source you trust.
        self.pipe = pipeline(
            "text-generation",
            self.MODEL_REPO,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for one request.

        Args:
            data: Request payload. ``data["inputs"]`` is the user text that
                is substituted into ``PROMPT_FORMAT``. The optional
                ``data["max_new_tokens"]`` must be convertible with
                ``int()``; it defaults to ``DEFAULT_MAX_NEW_TOKENS``.

        Returns:
            Whatever ``self.pipe`` returns for the formatted prompt
            (presumably a list of dicts from the transformers pipeline --
            confirm against the pipeline documentation). If
            ``max_new_tokens`` cannot be converted to an int, a JSON-encoded
            string describing the error is returned instead.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
        """
        # Validate the cheap parameter first so a bad request never reaches
        # the model.
        try:
            max_new_tokens = int(
                data.get("max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS)
            )
        except (TypeError, ValueError, OverflowError):
            # Kept as a JSON string (not a dict) so existing clients that
            # parse the error body keep working.
            return json.dumps({
                "status": "error",
                "reason": "max_new_tokens was passed as something that was absolutely not a plain old int"
            })

        prompt = PROMPT_FORMAT.format(inputs=data["inputs"])
        # Greedy decoding (do_sample=False) with the pipeline built in
        # __init__.
        return self.pipe(
            prompt,
            do_sample=False,
            max_new_tokens=max_new_tokens,
        )